diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index e63002e84..034ce309b 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -2,8 +2,7 @@ name: Bug report about: Create a report to help us improve title: '' -labels: bug -assignees: lepture +type: 'Bug' --- @@ -13,7 +12,7 @@ A clear and concise description of what the bug is. **Error Stacks** -``` +```python put error stacks here ``` diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index f0291e059..e947976c6 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -4,6 +4,7 @@ about: Suggest an idea for this project title: '' labels: '' assignees: '' +type: 'Feature' --- diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 14b2290f3..c3331ecfe 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,22 +1,35 @@ + -**What kind of change does this PR introduce?** (check at least one) +**What kind of change does this PR introduce?** -- [ ] Bugfix -- [ ] Feature -- [ ] Code style update -- [ ] Refactor -- [ ] Other, please describe: + -- [ ] Yes -- [ ] No +**Does this PR introduce a breaking change?** -If yes, please describe the impact and migration path for existing applications: + -(If no, please delete the above question and this text message.) +**Checklist** + +- [ ] The commits follow the [conventional commits](https://www.conventionalcommits.org) specification. +- [ ] You ran the linters with ``prek``. +- [ ] You wrote unit test to demonstrate the bug you are fixing, or to stress the feature you are bringing. +- [ ] You reached 100% of code coverage on the code you edited, without abusive use of `pragma: no cover` +- [ ] If this PR is about a new feature, or a behavior change, you have updated the documentation accordingly. --- diff --git a/.github/SECURITY.md b/.github/SECURITY.md new file mode 100644 index 000000000..c714fb0d5 --- /dev/null +++ b/.github/SECURITY.md @@ -0,0 +1,17 @@ +# Security Policy + +## Supported Versions + +| Version | Supported | +| ------- | ------------------ | +| 1.1.x | :white_check_mark: | +| 0.15.x | :white_check_mark: | +| < 0.15 | :x: | + +## Reporting a Vulnerability + +If you found security bugs, please **do not send a public issue or patch**. +You can send me email at . + +Or, you can use the [Tidelift security contact](https://tidelift.com/security). +Tidelift will coordinate the fix and disclosure. diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 000000000..004997706 --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,41 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: [ master ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ master ] + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + steps: + - name: Checkout repository + uses: actions/checkout@v6 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v4 + with: + languages: python + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v4 diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 000000000..3be4d3b19 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,25 @@ +name: docs + +on: + push: + branches-ignore: + - 'wip-*' + pull_request: + branches-ignore: + - 'wip-*' + +env: + FORCE_COLOR: '1' + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 + with: + enable-cache: true + - run: | + uv sync --all-groups + uv run sphinx-build docs build/sphinx/html --fail-on-warning diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml new file mode 100644 index 000000000..bd93c73fb --- /dev/null +++ b/.github/workflows/pypi.yml @@ -0,0 +1,56 @@ +name: Release to PyPI + +permissions: + contents: write + id-token: write + +on: + push: + tags: + - "v1.*" + +env: + FORCE_COLOR: '1' + +jobs: + build: + name: build dist files + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v6 + + - uses: actions/setup-python@v6 + with: + python-version: 3.14 + + - name: install build + run: python -m pip install --upgrade build + + - name: build dist + run: python -m build + + - uses: actions/upload-artifact@v4 + with: + name: artifacts + path: dist/* + if-no-files-found: error + + publish: + environment: + name: pypi-release + url: https://pypi.org/project/Authlib/ + permissions: + id-token: write + name: release to pypi + needs: build + runs-on: ubuntu-latest + + steps: + - uses: actions/download-artifact@v4 + with: + name: artifacts + path: dist + + - name: Push build artifacts to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 4db46eb98..43ce02186 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -6,12 +6,16 @@ on: - 'wip-*' paths-ignore: - 'docs/**' + - 'README.md' pull_request: branches-ignore: - 'wip-*' paths-ignore: - 'docs/**' +env: + FORCE_COLOR: '1' + jobs: build: @@ -21,43 +25,65 @@ jobs: max-parallel: 3 matrix: python: - - version: 2.7 - toxenv: py27,py27-flask - - version: 3.6 - toxenv: py36,flask,django,py3 - - version: 3.7 - toxenv: py37,flask,django,py3 - - version: 3.8 - toxenv: py38,flask,django,py3 + - version: "3.10" + - version: "3.11" + - version: "3.12" + - version: "3.13" + - version: "3.14" + - version: "pypy@3.11" steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python.version }} - uses: actions/setup-python@v1.1.1 + - uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: Install uv + uses: astral-sh/setup-uv@v7 with: - python-version: ${{ matrix.python.version }} + enable-cache: true + cache-dependency-glob: | + **/uv.lock + + - name: Set up Python ${{ matrix.python.version }} + run: | + uv python install ${{ matrix.python.version }} + uv python pin ${{ matrix.python.version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install tox - pip install -r requirements-test.txt + uv sync - - name: Test with tox ${{ matrix.python.toxenv }} + - name: Test with tox env: - TOXENV: ${{ matrix.python.toxenv }} - run: tox + TOXENV: py,jose,clients,flask,django + run: | + uvx --with tox-uv tox -p auto - name: Report coverage run: | - coverage combine - coverage report - coverage xml + uv run coverage combine + uv run coverage report + uv run coverage xml + + - name: Check diff coverage for modified files + if: github.event_name == 'pull_request' + run: | + uv run diff-cover coverage.xml --compare-branch=origin/${{ github.base_ref }} --fail-under=100 --format github-annotations:warning - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v1.0.5 + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} - file: ./coverage.xml + files: ./coverage.xml flags: unittests name: GitHub + + - name: SonarCloud Scan + uses: SonarSource/sonarqube-scan-action@v6 + continue-on-error: true + env: + SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} + + - name: Minimize cache + run: | + uv cache prune --ci diff --git a/.gitignore b/.gitignore index b0bcd0b13..8ec5f8772 100644 --- a/.gitignore +++ b/.gitignore @@ -12,9 +12,13 @@ parts .installed.cfg docs/_build htmlcov/ +.venv/ venv/ .tox .coverage* .pytest_cache/ *.egg .idea/ +uv.lock +.env +coverage.xml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..d522ea9b9 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,30 @@ +--- +default_install_hook_types: + - pre-commit + - commit-msg +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.15.8 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + - id: ruff-format + - repo: https://github.com/codespell-project/codespell + rev: v2.4.2 + hooks: + - id: codespell + stages: [pre-commit] + additional_dependencies: + - tomli + exclude: "docs/locales" + args: [--write-changes] + - repo: https://github.com/compilerla/conventional-pre-commit + rev: v4.4.0 + hooks: + - id: conventional-pre-commit + stages: [commit-msg] + args: [ + "--verbose", + "--scope", + "jose,oauth,oidc,client", + ] diff --git a/.py27conf b/.py27conf deleted file mode 100644 index d5da6ab0c..000000000 --- a/.py27conf +++ /dev/null @@ -1,16 +0,0 @@ -[coverage:run] -branch = True -omit = - authlib/integrations/base_client/async_app.py - authlib/integrations/httpx_client/* - authlib/integrations/starlette_client/* - - -[coverage:report] -exclude_lines = - pragma: no cover - except ImportError - def __repr__ - raise NotImplementedError - raise DeprecationWarning - deprecate diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 000000000..0432a0f62 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,17 @@ +--- +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.14" + jobs: + post_create_environment: + - pip install uv + - uv export --group docs --group clients --group flask --no-hashes --output-file requirements.txt + post_install: + - pip install . + - pip install --requirement requirements.txt + +sphinx: + configuration: docs/conf.py diff --git a/BACKERS.md b/BACKERS.md index bbf1ad323..fdc247447 100644 --- a/BACKERS.md +++ b/BACKERS.md @@ -13,16 +13,43 @@ Many thanks to these awesome sponsors and backers. -For quickly implementing token-based authencation, feel free to check Authing's Python SDK. +For quickly implementing token-based authentication, feel free to check Authing's Python SDK. + + + +Kraken is the world's leading customer & culture platform for energy, water & broadband. Licensing enquiries at Kraken.tech. + + + + + + + +
+ +Sentry +
+Sentry +
+ +Indeed +
+Indeed +
+ +Around +
+Around +
## Awesome Backers - + + + + + + +
+ Aveline
@@ -36,9 +63,51 @@ Aveline
-Callam +Callam
Callam
+ +Krishna Kumar +
+Krishna Kumar +
+ +Junnplus +
+Jun +
+ +Malik Piara +
+Malik Piara +
+ +Alan +
+Alan +
+ +Alan +
+Jeff Heaton +
+ +Alan +
+Birk Jernström +
+ +Yaal Coop +
+Yaal Coop +
diff --git a/Makefile b/Makefile index 617a66e22..936f6d218 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,7 @@ -.PHONY: tests clean clean-pyc clean-build docs +.PHONY: tests clean clean-pyc clean-build docs build + +build: + @python3 -m build clean: clean-build clean-pyc clean-docs clean-tox @@ -24,5 +27,5 @@ clean-docs: clean-tox: @rm -rf .tox/ -docs: - @$(MAKE) -C docs html +build-docs: + @sphinx-build docs build/_html -a diff --git a/README.md b/README.md index 3c6e54f27..3d10fb979 100644 --- a/README.md +++ b/README.md @@ -1,97 +1,109 @@ - - - +
-# Authlib + + + Authlib + - -Build Status -Coverage Status -PyPI Version -Maintainability -Follow Twitter +[![Build Status](https://github.com/authlib/authlib/workflows/tests/badge.svg)](https://github.com/authlib/authlib/actions) +[![PyPI version](https://img.shields.io/pypi/v/authlib.svg)](https://pypi.org/project/authlib) +[![conda-forge version](https://img.shields.io/conda/v/conda-forge/authlib.svg?label=conda-forge&colorB=0090ff)](https://anaconda.org/conda-forge/authlib) +[![PyPI Downloads](https://static.pepy.tech/badge/authlib/month)](https://pepy.tech/projects/authlib) +[![Code Coverage](https://codecov.io/gh/authlib/authlib/graph/badge.svg?token=OWTdxAIsPI)](https://codecov.io/gh/authlib/authlib) +[![Maintainability Rating](https://sonarcloud.io/api/project_badges/measure?project=authlib_authlib&metric=sqale_rating)](https://sonarcloud.io/summary/new_code?id=authlib_authlib) + +
The ultimate Python library in building OAuth and OpenID Connect servers. JWS, JWK, JWA, JWT are included. -Authlib is compatible with Python2.7+ and Python3.6+. +Authlib is compatible with Python3.10+. + +## Migrations + +Authlib will deprecate `authlib.jose` module, please read: + +- [Migrating from `authlib.jose` to `joserfc`](https://jose.authlib.org/en/dev/migrations/authlib/) + +## Sponsors + + + + + + + + + + +
If you want to quickly add secure token-based authentication to Python projects, feel free to check Auth0's Python SDK and free plan at auth0.com/overview.
A blogging and podcast hosting platform with minimal design but powerful features. Host your blog and Podcast with Typlog.com. +
+ +[**Fund Authlib to access additional features**](https://docs.authlib.org/en/stable/community/funding.html) ## Features Generic, spec-compliant implementation to build clients and providers: -- [The OAuth 1.0 Protocol](https://docs.authlib.org/en/latest/basic/oauth1.html) - - [RFC5849: The OAuth 1.0 Protocol](https://docs.authlib.org/en/latest/specs/rfc5849.html) -- [The OAuth 2.0 Authorization Framework](https://docs.authlib.org/en/latest/basic/oauth2.html) - - [RFC6749: The OAuth 2.0 Authorization Framework](https://docs.authlib.org/en/latest/specs/rfc6749.html) - - [RFC6750: The OAuth 2.0 Authorization Framework: Bearer Token Usage](https://docs.authlib.org/en/latest/specs/rfc6750.html) - - [RFC7009: OAuth 2.0 Token Revocation](https://docs.authlib.org/en/latest/specs/rfc7009.html) - - [RFC7591: OAuth 2.0 Dynamic Client Registration Protocol](https://docs.authlib.org/en/latest/specs/rfc7591.html) - - [ ] RFC7592: OAuth 2.0 Dynamic Client Registration Management Protocol - - [RFC7636: Proof Key for Code Exchange by OAuth Public Clients](https://docs.authlib.org/en/latest/specs/rfc7636.html) - - [RFC7662: OAuth 2.0 Token Introspection](https://docs.authlib.org/en/latest/specs/rfc7662.html) - - [RFC8414: OAuth 2.0 Authorization Server Metadata](https://docs.authlib.org/en/latest/specs/rfc8414.html) - - [RFC8628: OAuth 2.0 Device Authorization Grant](https://docs.authlib.org/en/latest/specs/rfc8628.html) -- [Javascript Object Signing and Encryption](https://docs.authlib.org/en/latest/jose/index.html) - - [RFC7515: JSON Web Signature](https://docs.authlib.org/en/latest/jose/jws.html) - - [RFC7516: JSON Web Encryption](https://docs.authlib.org/en/latest/jose/jwe.html) - - [RFC7517: JSON Web Key](https://docs.authlib.org/en/latest/jose/jwk.html) - - [RFC7518: JSON Web Algorithms](https://docs.authlib.org/en/latest/specs/rfc7518.html) - - [RFC7519: JSON Web Token](https://docs.authlib.org/en/latest/jose/jwt.html) - - [RFC7638: JSON Web Key (JWK) Thumbprint](https://docs.authlib.org/en/latest/specs/rfc7638.html) +- [The OAuth 1.0 Protocol](https://docs.authlib.org/en/stable/basic/oauth1.html) + - [RFC5849: The OAuth 1.0 Protocol](https://docs.authlib.org/en/stable/specs/rfc5849.html) +- [The OAuth 2.0 Authorization Framework](https://docs.authlib.org/en/stable/basic/oauth2.html) + - [RFC6749: The OAuth 2.0 Authorization Framework](https://docs.authlib.org/en/stable/specs/rfc6749.html) + - [RFC6750: The OAuth 2.0 Authorization Framework: Bearer Token Usage](https://docs.authlib.org/en/stable/specs/rfc6750.html) + - [RFC7009: OAuth 2.0 Token Revocation](https://docs.authlib.org/en/stable/specs/rfc7009.html) + - [RFC7523: JWT Profile for OAuth 2.0 Client Authentication and Authorization Grants](https://docs.authlib.org/en/stable/specs/rfc7523.html) + - [RFC7591: OAuth 2.0 Dynamic Client Registration Protocol](https://docs.authlib.org/en/stable/specs/rfc7591.html) + - [RFC7592: OAuth 2.0 Dynamic Client Registration Management Protocol](https://docs.authlib.org/en/stable/specs/rfc7592.html) + - [RFC7636: Proof Key for Code Exchange by OAuth Public Clients](https://docs.authlib.org/en/stable/specs/rfc7636.html) + - [RFC7662: OAuth 2.0 Token Introspection](https://docs.authlib.org/en/stable/specs/rfc7662.html) + - [RFC8414: OAuth 2.0 Authorization Server Metadata](https://docs.authlib.org/en/stable/specs/rfc8414.html) + - [RFC8628: OAuth 2.0 Device Authorization Grant](https://docs.authlib.org/en/stable/specs/rfc8628.html) + - [RFC9068: JSON Web Token (JWT) Profile for OAuth 2.0 Access Tokens](https://docs.authlib.org/en/stable/specs/rfc9068.html) + - [RFC9101: The OAuth 2.0 Authorization Framework: JWT-Secured Authorization Request (JAR)](https://docs.authlib.org/en/stable/specs/rfc9101.html) + - [RFC9207: OAuth 2.0 Authorization Server Issuer Identification](https://docs.authlib.org/en/stable/specs/rfc9207.html) +- [Javascript Object Signing and Encryption](https://docs.authlib.org/en/stable/jose/index.html) + - [RFC7515: JSON Web Signature](https://docs.authlib.org/en/stable/jose/jws.html) + - [RFC7516: JSON Web Encryption](https://docs.authlib.org/en/stable/jose/jwe.html) + - [RFC7517: JSON Web Key](https://docs.authlib.org/en/stable/jose/jwk.html) + - [RFC7518: JSON Web Algorithms](https://docs.authlib.org/en/stable/specs/rfc7518.html) + - [RFC7519: JSON Web Token](https://docs.authlib.org/en/stable/jose/jwt.html) + - [RFC7638: JSON Web Key (JWK) Thumbprint](https://docs.authlib.org/en/stable/specs/rfc7638.html) - [ ] RFC7797: JSON Web Signature (JWS) Unencoded Payload Option - - [RFC8037: ECDH in JWS and JWE](https://docs.authlib.org/en/latest/specs/rfc8037.html) -- [OpenID Connect 1.0](https://docs.authlib.org/en/latest/specs/oidc.html) + - [RFC8037: ECDH in JWS and JWE](https://docs.authlib.org/en/stable/specs/rfc8037.html) + - [ ] draft-madden-jose-ecdh-1pu-04: Public Key Authenticated Encryption for JOSE: ECDH-1PU +- [OpenID Connect 1.0](https://docs.authlib.org/en/stable/specs/oidc.html) - [x] OpenID Connect Core 1.0 - [x] OpenID Connect Discovery 1.0 + - [x] OpenID Connect Dynamic Client Registration 1.0 + - [x] [OpenID Connect RP-Initiated Logout 1.0](https://openid.net/specs/openid-connect-rpinitiated-1_0.html) Connect third party OAuth providers with Authlib built-in client integrations: - Requests - - [OAuth1Session](https://docs.authlib.org/en/latest/client/requests.html#requests-oauth-1-0) - - [OAuth2Session](https://docs.authlib.org/en/latest/client/requests.html#requests-oauth-2-0) - - [OpenID Connect](https://docs.authlib.org/en/latest/client/requests.html#requests-openid-connect) - - [AssertionSession](https://docs.authlib.org/en/latest/client/requests.html#requests-service-account) + - [OAuth1Session](https://docs.authlib.org/en/stable/client/requests.html#requests-oauth-1-0) + - [OAuth2Session](https://docs.authlib.org/en/stable/client/requests.html#requests-oauth-2-0) + - [OpenID Connect](https://docs.authlib.org/en/stable/client/requests.html#requests-openid-connect) + - [AssertionSession](https://docs.authlib.org/en/stable/client/requests.html#requests-service-account) - HTTPX - - [AsyncOAuth1Client](https://docs.authlib.org/en/latest/client/httpx.html#httpx-oauth-1-0) - - [AsyncOAuth2Client](https://docs.authlib.org/en/latest/client/httpx.html#httpx-oauth-2-0) - - [OpenID Connect](https://docs.authlib.org/en/latest/client/httpx.html#httpx-oauth-2-0) - - [AsyncAssertionClient](https://docs.authlib.org/en/latest/client/httpx.html#async-service-account) -- [Flask OAuth Client](https://docs.authlib.org/en/latest/client/flask.html) -- [Django OAuth Client](https://docs.authlib.org/en/latest/client/django.html) -- [Starlette OAuth Client](https://docs.authlib.org/en/latest/client/starlette.html) -- [FastAPI OAuth Client](https://docs.authlib.org/en/latest/client/fastapi.html) + - [AsyncOAuth1Client](https://docs.authlib.org/en/stable/client/httpx.html#httpx-oauth-1-0) + - [AsyncOAuth2Client](https://docs.authlib.org/en/stable/client/httpx.html#httpx-oauth-2-0) + - [OpenID Connect](https://docs.authlib.org/en/stable/client/httpx.html#httpx-oauth-2-0) + - [AsyncAssertionClient](https://docs.authlib.org/en/stable/client/httpx.html#async-service-account) +- [Flask OAuth Client](https://docs.authlib.org/en/stable/client/flask.html) +- [Django OAuth Client](https://docs.authlib.org/en/stable/client/django.html) +- [Starlette OAuth Client](https://docs.authlib.org/en/stable/client/starlette.html) +- [FastAPI OAuth Client](https://docs.authlib.org/en/stable/client/fastapi.html) Build your own OAuth 1.0, OAuth 2.0, and OpenID Connect providers: - Flask - - [Flask OAuth 1.0 Provider](https://docs.authlib.org/en/latest/flask/1/) - - [Flask OAuth 2.0 Provider](https://docs.authlib.org/en/latest/flask/2/) - - [Flask OpenID Connect 1.0 Provider](https://docs.authlib.org/en/latest/flask/2/openid-connect.html) + - [Flask OAuth 1.0 Provider](https://docs.authlib.org/en/stable/flask/1/) + - [Flask OAuth 2.0 Provider](https://docs.authlib.org/en/stable/flask/2/) + - [Flask OpenID Connect 1.0 Provider](https://docs.authlib.org/en/stable/flask/2/openid-connect.html) - Django - - [Django OAuth 1.0 Provider](https://docs.authlib.org/en/latest/django/1/) - - [Django OAuth 2.0 Provider](https://docs.authlib.org/en/latest/django/2/) - - [Django OpenID Connect 1.0 Provider](https://docs.authlib.org/en/latest/django/2/openid-connect.html) - -## Sponsors - - - - - - - - - - - - - - -
If you want to quickly add secure token-based authentication to Python projects, feel free to check Auth0's Python SDK and free plan at auth0.com/developers.
For quickly implementing token-based authentication, feel free to check Authing's Python SDK.
Get professionally-supported Authlib with the Tidelift Subscription. -
- -[**Support Me via GitHub Sponsors**](https://github.com/users/lepture/sponsorship). + - [Django OAuth 1.0 Provider](https://docs.authlib.org/en/stable/django/1/) + - [Django OAuth 2.0 Provider](https://docs.authlib.org/en/stable/django/2/) + - [Django OpenID Connect 1.0 Provider](https://docs.authlib.org/en/stable/django/2/openid-connect.html) ## Useful Links @@ -121,19 +133,10 @@ Tidelift will coordinate the fix and disclosure. Authlib offers two licenses: -1. BSD (LICENSE) +1. BSD LICENSE 2. COMMERCIAL-LICENSE -Companies can purchase a commercial license at -[Authlib Plans](https://authlib.org/plans). - -**If your company is creating a closed source OAuth provider, it is strongly -suggested that your company purchasing a commercial license.** - -## Support - -If you need any help, you can always ask questions on StackOverflow with -a tag of "Authlib". DO NOT ASK HELP IN GITHUB ISSUES. - -We also provide commercial consulting and supports. You can find more -information at . +Any project, open or closed source, can use the BSD license. +If your company needs commercial support, you can purchase a commercial license at +[Authlib Plans](https://authlib.org/plans). You can find more information at +. diff --git a/README.rst b/README.rst deleted file mode 100644 index bb3f29419..000000000 --- a/README.rst +++ /dev/null @@ -1,69 +0,0 @@ -Authlib -======= - -The ultimate Python library in building OAuth and OpenID Connect servers. -JWS, JWK, JWA, JWT are included. - -Useful Links ------------- - -1. Homepage: https://authlib.org/ -2. Documentation: https://docs.authlib.org/ -3. Purchase Commercial License: https://authlib.org/plans -4. Blog: https://blog.authlib.org/ -5. More Repositories: https://github.com/authlib -6. Twitter: https://twitter.com/authlib -7. Donate: https://www.patreon.com/lepture - -Specifications --------------- - -- RFC5849: The OAuth 1.0 Protocol -- RFC6749: The OAuth 2.0 Authorization Framework -- RFC6750: The OAuth 2.0 Authorization Framework: Bearer Token Usage -- RFC7009: OAuth 2.0 Token Revocation -- RFC7515: JSON Web Signature -- RFC7516: JSON Web Encryption -- RFC7517: JSON Web Key -- RFC7518: JSON Web Algorithms -- RFC7519: JSON Web Token -- RFC7521: Assertion Framework for OAuth 2.0 Client Authentication and Authorization Grants -- RFC7523: JSON Web Token (JWT) Profile for OAuth 2.0 Client Authentication and Authorization Grants -- RFC7591: OAuth 2.0 Dynamic Client Registration Protocol -- RFC7636: Proof Key for Code Exchange by OAuth Public Clients -- RFC7638: JSON Web Key (JWK) Thumbprint -- RFC7662: OAuth 2.0 Token Introspection -- RFC8037: CFRG Elliptic Curve Diffie-Hellman (ECDH) and Signatures in JSON Object Signing and Encryption (JOSE) -- RFC8414: OAuth 2.0 Authorization Server Metadata -- RFC8628: OAuth 2.0 Device Authorization Grant -- OpenID Connect 1.0 -- OpenID Connect Discovery 1.0 - -Implementations ---------------- - -- Requests OAuth 1 Session -- Requests OAuth 2 Session -- Requests Assertion Session -- HTTPX OAuth 1 Session -- HTTPX OAuth 2 Session -- HTTPX Assertion Session -- Flask OAuth 1/2 Client -- Django OAuth 1/2 Client -- Starlette OAuth 1/2 Client -- Flask OAuth 1.0 Server -- Flask OAuth 2.0 Server -- Flask OpenID Connect 1.0 -- Django OAuth 1.0 Server -- Django OAuth 2.0 Server -- Django OpenID Connect 1.0 - -License -------- - -Authlib is licensed under BSD. Please see LICENSE for licensing details. - -If this license does not fit your company, consider to purchase a commercial -license. Find more information on `Authlib Plans`_. - -.. _`Authlib Plans`: https://authlib.org/plans diff --git a/authlib/__init__.py b/authlib/__init__.py index 2a2e5adc5..e30ed448b 100644 --- a/authlib/__init__.py +++ b/authlib/__init__.py @@ -1,17 +1,20 @@ """ - authlib - ~~~~~~~ +authlib +~~~~~~~ - The ultimate Python library in building OAuth 1.0, OAuth 2.0 and OpenID - Connect clients and providers. It covers from low level specification - implementation to high level framework integrations. +The ultimate Python library in building OAuth 1.0, OAuth 2.0 and OpenID +Connect clients and providers. It covers from low level specification +implementation to high level framework integrations. - :copyright: (c) 2017 by Hsiaoming Yang. - :license: BSD, see LICENSE for more details. +:copyright: (c) 2017 by Hsiaoming Yang. +:license: BSD, see LICENSE for more details. """ -from .consts import version, homepage, author + +from .consts import author +from .consts import homepage +from .consts import version __version__ = version __homepage__ = homepage __author__ = author -__license__ = 'BSD-3-Clause' +__license__ = "BSD-3-Clause" diff --git a/authlib/_joserfc_helpers.py b/authlib/_joserfc_helpers.py new file mode 100644 index 000000000..e26b64c1d --- /dev/null +++ b/authlib/_joserfc_helpers.py @@ -0,0 +1,45 @@ +from typing import Any + +from joserfc.jwk import KeySet +from joserfc.jwk import import_key + +from authlib.common.encoding import json_loads +from authlib.deprecate import deprecate +from authlib.jose import ECKey +from authlib.jose import OctKey +from authlib.jose import OKPKey +from authlib.jose import RSAKey + + +def import_any_key(data: Any): + if isinstance(data, (OctKey, RSAKey, ECKey, OKPKey)): + deprecate("Please use joserfc to import keys.", version="2.0.0") + return import_key(data.as_dict(is_private=not data.public_only)) + + if ( + isinstance(data, str) + and data.strip().startswith("{") + and data.strip().endswith("}") + ): + deprecate( + "Please use OctKey, RSAKey, ECKey, OKPKey, and KeySet directly.", + version="2.0.0", + ) + data = json_loads(data) + + if isinstance(data, (str, bytes)): + deprecate( + "Please use OctKey, RSAKey, ECKey, OKPKey, and KeySet directly.", + version="2.0.0", + ) + return import_key(data) + + elif isinstance(data, dict): + if "keys" in data: + deprecate( + "Please `KeySet.import_key_set` from `joserfc.jwk` to import jwks.", + version="2.0.0", + ) + return KeySet.import_key_set(data) + return import_key(data) + return data diff --git a/authlib/common/encoding.py b/authlib/common/encoding.py index 31df0b039..25063dc25 100644 --- a/authlib/common/encoding.py +++ b/authlib/common/encoding.py @@ -1,45 +1,31 @@ -import sys -import json import base64 +import json import struct -is_py2 = sys.version_info[0] == 2 - -if is_py2: - unicode_type = unicode # noqa: F821 - byte_type = str - text_types = (unicode, str) # noqa: F821 -else: - unicode_type = str - byte_type = bytes - text_types = (str, ) - -def to_bytes(x, charset='utf-8', errors='strict'): +def to_bytes(x, charset="utf-8", errors="strict"): if x is None: return None - if isinstance(x, byte_type): + if isinstance(x, bytes): return x - if isinstance(x, unicode_type): + if isinstance(x, str): return x.encode(charset, errors) if isinstance(x, (int, float)): return str(x).encode(charset, errors) - return byte_type(x) + return bytes(x) -def to_unicode(x, charset='utf-8', errors='strict'): - if x is None or isinstance(x, unicode_type): +def to_unicode(x, charset="utf-8", errors="strict"): + if x is None or isinstance(x, str): return x - if isinstance(x, byte_type): + if isinstance(x, bytes): return x.decode(charset, errors) - return unicode_type(x) + return str(x) -def to_native(x, encoding='ascii'): +def to_native(x, encoding="ascii"): if isinstance(x, str): return x - if is_py2: - return x.encode(encoding) return x.decode(encoding) @@ -48,37 +34,29 @@ def json_loads(s): def json_dumps(data, ensure_ascii=False): - return json.dumps(data, ensure_ascii=ensure_ascii, separators=(',', ':')) + return json.dumps(data, ensure_ascii=ensure_ascii, separators=(",", ":")) def urlsafe_b64decode(s): - s += b'=' * (-len(s) % 4) + s += b"=" * (-len(s) % 4) return base64.urlsafe_b64decode(s) def urlsafe_b64encode(s): - return base64.urlsafe_b64encode(s).rstrip(b'=') + return base64.urlsafe_b64encode(s).rstrip(b"=") def base64_to_int(s): - data = urlsafe_b64decode(to_bytes(s, charset='ascii')) - buf = struct.unpack('%sB' % len(data), data) - return int(''.join(["%02x" % byte for byte in buf]), 16) + data = urlsafe_b64decode(to_bytes(s, charset="ascii")) + buf = struct.unpack(f"{len(data)}B", data) + return int("".join([f"{byte:02x}" for byte in buf]), 16) def int_to_base64(num): if num < 0: - raise ValueError('Must be a positive integer') - - if hasattr(int, 'to_bytes'): - s = num.to_bytes((num.bit_length() + 7) // 8, 'big', signed=False) - else: - buf = [] - while num: - num, remainder = divmod(num, 256) - buf.append(remainder) - buf.reverse() - s = struct.pack('%sB' % len(buf), *buf) + raise ValueError("Must be a positive integer") + + s = num.to_bytes((num.bit_length() + 7) // 8, "big", signed=False) return to_unicode(urlsafe_b64encode(s)) diff --git a/authlib/common/errors.py b/authlib/common/errors.py index 015ab4beb..ece95896c 100644 --- a/authlib/common/errors.py +++ b/authlib/common/errors.py @@ -1,4 +1,3 @@ -#: coding: utf-8 from authlib.consts import default_json_headers @@ -8,7 +7,7 @@ class AuthlibBaseError(Exception): #: short-string error code error = None #: long-string to describe this error - description = '' + description = "" #: web page that describes this error uri = None @@ -20,57 +19,44 @@ def __init__(self, error=None, description=None, uri=None): if uri is not None: self.uri = uri - message = '{}: {}'.format(self.error, self.description) - super(AuthlibBaseError, self).__init__(message) + message = f"{self.error}: {self.description}" + super().__init__(message) def __repr__(self): - return '<{} "{}">'.format(self.__class__.__name__, self.error) + return f'<{self.__class__.__name__} "{self.error}">' class AuthlibHTTPError(AuthlibBaseError): #: HTTP status code status_code = 400 - def __init__(self, error=None, description=None, uri=None, - status_code=None): - super(AuthlibHTTPError, self).__init__(error, description, uri) + def __init__(self, error=None, description=None, uri=None, status_code=None): + super().__init__(error, description, uri) if status_code is not None: self.status_code = status_code - self._translations = None - self._error_uris = None - - def gettext(self, s): - if self._translations: - return self._translations.gettext(s) - return s def get_error_description(self): return self.description - def get_error_uri(self): - if self.uri: - return self.uri - if self._error_uris: - return self._error_uris.get(self.error) - def get_body(self): - error = [('error', self.error)] + error = [("error", self.error)] - description = self.get_error_description() - if description: - error.append(('error_description', description)) + if self.description: + error.append(("error_description", self.description)) - uri = self.get_error_uri() - if uri: - error.append(('error_uri', uri)) + if self.uri: + error.append(("error_uri", self.uri)) return error def get_headers(self): return default_json_headers[:] - def __call__(self, translations=None, error_uris=None): - self._translations = translations - self._error_uris = error_uris + def __call__(self, uri=None): + self.uri = uri body = dict(self.get_body()) headers = self.get_headers() return self.status_code, body, headers + + +class ContinueIteration(AuthlibBaseError): + pass diff --git a/authlib/common/security.py b/authlib/common/security.py index b05ea1443..2dd5e32c6 100644 --- a/authlib/common/security.py +++ b/authlib/common/security.py @@ -1,19 +1,21 @@ import os -import string import random +import string UNICODE_ASCII_CHARACTER_SET = string.ascii_letters + string.digits def generate_token(length=30, chars=UNICODE_ASCII_CHARACTER_SET): rand = random.SystemRandom() - return ''.join(rand.choice(chars) for _ in range(length)) + return "".join(rand.choice(chars) for _ in range(length)) def is_secure_transport(uri): """Check if the uri is over ssl.""" - if os.getenv('AUTHLIB_INSECURE_TRANSPORT'): + if os.getenv("AUTHLIB_INSECURE_TRANSPORT"): return True uri = uri.lower() - return uri.startswith(('https://', 'http://localhost:')) + return uri.startswith( + ("https://", "http://localhost:", "http://127.0.0.1:", "http://[::1]:") + ) diff --git a/authlib/common/urls.py b/authlib/common/urls.py index d03b1735c..e2a8b8559 100644 --- a/authlib/common/urls.py +++ b/authlib/common/urls.py @@ -1,41 +1,21 @@ -""" - authlib.util.urls - ~~~~~~~~~~~~~~~~~ +"""authlib.util.urls. +~~~~~~~~~~~~~~~~~ - Wrapper functions for URL encoding and decoding. +Wrapper functions for URL encoding and decoding. """ import re -try: - from urllib import quote as _quote - from urllib import unquote as _unquote - from urllib import urlencode as _urlencode -except ImportError: - from urllib.parse import quote as _quote - from urllib.parse import unquote as _unquote - from urllib.parse import urlencode as _urlencode - -try: - from urllib2 import parse_keqv_list # noqa: F401 - from urllib2 import parse_http_list # noqa: F401 -except ImportError: - from urllib.request import parse_keqv_list # noqa: F401 - from urllib.request import parse_http_list # noqa: F401 - -try: - import urlparse -except ImportError: - import urllib.parse as urlparse - -from .encoding import to_unicode, to_bytes - -always_safe = ( - 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' - 'abcdefghijklmnopqrstuvwxyz' - '0123456789_.-' -) -urlencoded = set(always_safe) | set('=&;:%+~,*@!()/?') -INVALID_HEX_PATTERN = re.compile(r'%[^0-9A-Fa-f]|%[0-9A-Fa-f][^0-9A-Fa-f]') +import urllib.parse as urlparse +from urllib.parse import quote as _quote +from urllib.parse import unquote as _unquote +from urllib.parse import urlencode as _urlencode + +from .encoding import to_bytes +from .encoding import to_unicode + +always_safe = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_.-" +urlencoded = set(always_safe) | set("=&;:%+~,*@!()/?") +INVALID_HEX_PATTERN = re.compile(r"%[^0-9A-Fa-f]|%[0-9A-Fa-f][^0-9A-Fa-f]") def url_encode(params): @@ -56,11 +36,13 @@ def url_decode(query): """ # Check if query contains invalid characters if query and not set(query) <= urlencoded: - error = ("Error trying to decode a non urlencoded string. " - "Found invalid characters: %s " - "in the string: '%s'. " - "Please ensure the request/response body is " - "x-www-form-urlencoded.") + error = ( + "Error trying to decode a non urlencoded string. " + "Found invalid characters: %s " + "in the string: '%s'. " + "Please ensure the request/response body is " + "x-www-form-urlencoded." + ) raise ValueError(error % (set(query) - urlencoded, query)) # Check for correctly hex encoded values using a regular expression @@ -68,7 +50,7 @@ def url_decode(query): # correct = %00, %A0, %0A, %FF # invalid = %G0, %5H, %PO if INVALID_HEX_PATTERN.search(query): - raise ValueError('Invalid hex encoding in query string.') + raise ValueError("Invalid hex encoding in query string.") # We encode to utf-8 prior to parsing because parse_qsl behaves # differently on unicode input in python 2 and 3. @@ -116,7 +98,7 @@ def add_params_to_uri(uri, params, fragment=False): return urlparse.urlunparse((sch, net, path, par, query, fra)) -def quote(s, safe=b'/'): +def quote(s, safe=b"/"): return to_unicode(_quote(to_bytes(s), safe)) @@ -125,7 +107,7 @@ def unquote(s): def quote_url(s): - return quote(s, b'~@#$&()*!+=:;,.?/\'') + return quote(s, b"~@#$&()*!+=:;,.?/'") def extract_params(raw): @@ -157,6 +139,8 @@ def extract_params(raw): return None -def is_valid_url(url): +def is_valid_url(url: str, fragments_allowed=True): parsed = urlparse.urlparse(url) - return parsed.scheme and parsed.hostname + return ( + parsed.scheme and parsed.hostname and (fragments_allowed or not parsed.fragment) + ) diff --git a/authlib/consts.py b/authlib/consts.py index 41622baec..f4a532e47 100644 --- a/authlib/consts.py +++ b/authlib/consts.py @@ -1,11 +1,11 @@ -name = 'Authlib' -version = '0.15.dev' -author = 'Hsiaoming Yang ' -homepage = 'https://authlib.org/' -default_user_agent = '{}/{} (+{})'.format(name, version, homepage) +name = "Authlib" +version = "1.7.0" +author = "Hsiaoming Yang " +homepage = "https://authlib.org" +default_user_agent = f"{name}/{version} (+{homepage})" default_json_headers = [ - ('Content-Type', 'application/json'), - ('Cache-Control', 'no-store'), - ('Pragma', 'no-cache'), + ("Content-Type", "application/json"), + ("Cache-Control", "no-store"), + ("Pragma", "no-cache"), ] diff --git a/authlib/deprecate.py b/authlib/deprecate.py index ba87f3c32..745494f72 100644 --- a/authlib/deprecate.py +++ b/authlib/deprecate.py @@ -5,12 +5,11 @@ class AuthlibDeprecationWarning(DeprecationWarning): pass -warnings.simplefilter('always', AuthlibDeprecationWarning) +warnings.simplefilter("always", AuthlibDeprecationWarning) -def deprecate(message, version=None, link_uid=None, link_file=None): +def deprecate(message, version=None, stacklevel=3): if version: - message += '\nIt will be compatible before version {}.'.format(version) - if link_uid and link_file: - message += '\nRead more '.format(link_uid, link_file) - warnings.warn(AuthlibDeprecationWarning(message), stacklevel=2) + message += f"\nIt will be compatible before version {version}." + + warnings.warn(AuthlibDeprecationWarning(message), stacklevel=stacklevel) diff --git a/authlib/integrations/base_client/__init__.py b/authlib/integrations/base_client/__init__.py index 4fa35b8a3..e9e352dba 100644 --- a/authlib/integrations/base_client/__init__.py +++ b/authlib/integrations/base_client/__init__.py @@ -1,16 +1,29 @@ -from .base_oauth import BaseOAuth -from .base_app import BaseApp -from .remote_app import RemoteApp +from .errors import InvalidTokenError +from .errors import MismatchingStateError +from .errors import MissingRequestTokenError +from .errors import MissingTokenError +from .errors import OAuthError +from .errors import TokenExpiredError +from .errors import UnsupportedTokenTypeError from .framework_integration import FrameworkIntegration -from .errors import ( - OAuthError, MissingRequestTokenError, MissingTokenError, - TokenExpiredError, InvalidTokenError, UnsupportedTokenTypeError, - MismatchingStateError, -) +from .registry import BaseOAuth +from .sync_app import BaseApp +from .sync_app import OAuth1Mixin +from .sync_app import OAuth2Mixin +from .sync_openid import OpenIDMixin __all__ = [ - 'BaseOAuth', 'BaseApp', 'RemoteApp', 'FrameworkIntegration', - 'OAuthError', 'MissingRequestTokenError', 'MissingTokenError', - 'TokenExpiredError', 'InvalidTokenError', 'UnsupportedTokenTypeError', - 'MismatchingStateError', + "BaseOAuth", + "BaseApp", + "OAuth1Mixin", + "OAuth2Mixin", + "OpenIDMixin", + "FrameworkIntegration", + "OAuthError", + "MissingRequestTokenError", + "MissingTokenError", + "TokenExpiredError", + "InvalidTokenError", + "UnsupportedTokenTypeError", + "MismatchingStateError", ] diff --git a/authlib/integrations/base_client/async_app.py b/authlib/integrations/base_client/async_app.py index 60d3d734c..e755ab557 100644 --- a/authlib/integrations/base_client/async_app.py +++ b/authlib/integrations/base_client/async_app.py @@ -1,27 +1,70 @@ -import time import logging +import time + from authlib.common.urls import urlparse -from authlib.jose import JsonWebToken, JsonWebKey -from authlib.oidc.core import UserInfo, CodeIDToken, ImplicitIDToken -from .base_app import BaseApp -from .errors import ( - MissingRequestTokenError, - MissingTokenError, -) -__all__ = ['AsyncRemoteApp'] +from .errors import MissingRequestTokenError +from .errors import MissingTokenError +from .sync_app import OAuth1Base +from .sync_app import OAuth2Base log = logging.getLogger(__name__) +__all__ = ["AsyncOAuth1Mixin", "AsyncOAuth2Mixin"] -class AsyncRemoteApp(BaseApp): - async def load_server_metadata(self): - if self._server_metadata_url and '_loaded_at' not in self.server_metadata: - metadata = await self._fetch_server_metadata(self._server_metadata_url) - metadata['_loaded_at'] = time.time() - self.server_metadata.update(metadata) - return self.server_metadata +class AsyncOAuth1Mixin(OAuth1Base): + async def request(self, method, url, token=None, **kwargs): + async with self._get_oauth_client() as session: + return await _http_request(self, session, method, url, token, kwargs) + + async def create_authorization_url(self, redirect_uri=None, **kwargs): + """Generate the authorization url and state for HTTP redirect. + + :param redirect_uri: Callback or redirect URI for authorization. + :param kwargs: Extra parameters to include. + :return: dict + """ + if not self.authorize_url: + raise RuntimeError('Missing "authorize_url" value') + + if self.authorize_params: + kwargs.update(self.authorize_params) + + async with self._get_oauth_client() as client: + client.redirect_uri = redirect_uri + params = {} + if self.request_token_params: + params.update(self.request_token_params) + request_token = await client.fetch_request_token( + self.request_token_url, **params + ) + log.debug(f"Fetch request token: {request_token!r}") + url = client.create_authorization_url(self.authorize_url, **kwargs) + state = request_token["oauth_token"] + return {"url": url, "request_token": request_token, "state": state} + + async def fetch_access_token(self, request_token=None, **kwargs): + """Fetch access token in one step. + + :param request_token: A previous request token for OAuth 1. + :param kwargs: Extra parameters to fetch access token. + :return: A token dict. + """ + async with self._get_oauth_client() as client: + if request_token is None: + raise MissingRequestTokenError() + # merge request token with verifier + token = {} + token.update(request_token) + token.update(kwargs) + client.token = token + params = self.access_token_params or {} + token = await client.fetch_access_token(self.access_token_url, **params) + return token + + +class AsyncOAuth2Mixin(OAuth2Base): async def _on_update_token(self, token, refresh_token=None, access_token=None): if self._update_token: await self._update_token( @@ -30,16 +73,22 @@ async def _on_update_token(self, token, refresh_token=None, access_token=None): access_token=access_token, ) - async def _create_oauth1_authorization_url(self, client, authorization_endpoint, **kwargs): - params = {} - if self.request_token_params: - params.update(self.request_token_params) - token = await client.fetch_request_token( - self.request_token_url, **params - ) - log.debug('Fetch request token: {!r}'.format(token)) - url = client.create_authorization_url(authorization_endpoint, **kwargs) - return {'url': url, 'request_token': token} + async def load_server_metadata(self): + if self._server_metadata_url and "_loaded_at" not in self.server_metadata: + async with self._get_session() as client: + resp = await client.request( + "GET", self._server_metadata_url, withhold_token=True + ) + resp.raise_for_status() + metadata = resp.json() + metadata["_loaded_at"] = time.time() + self.server_metadata.update(metadata) + return self.server_metadata + + async def request(self, method, url, token=None, **kwargs): + metadata = await self.load_server_metadata() + async with self._get_oauth_client(**metadata) as session: + return await _http_request(self, session, method, url, token, kwargs) async def create_authorization_url(self, redirect_uri=None, **kwargs): """Generate the authorization url and state for HTTP redirect. @@ -49,10 +98,9 @@ async def create_authorization_url(self, redirect_uri=None, **kwargs): :return: dict """ metadata = await self.load_server_metadata() - authorization_endpoint = self.authorize_url - if not authorization_endpoint and not self.request_token_url: - authorization_endpoint = metadata.get('authorization_endpoint') - + authorization_endpoint = self.authorize_url or metadata.get( + "authorization_endpoint" + ) if not authorization_endpoint: raise RuntimeError('Missing "authorize_url" value') @@ -61,145 +109,44 @@ async def create_authorization_url(self, redirect_uri=None, **kwargs): async with self._get_oauth_client(**metadata) as client: client.redirect_uri = redirect_uri + return self._create_oauth2_authorization_url( + client, authorization_endpoint, **kwargs + ) - if self.request_token_url: - return await self._create_oauth1_authorization_url( - client, authorization_endpoint, **kwargs) - else: - return self._create_oauth2_authorization_url( - client, authorization_endpoint, **kwargs) - - async def fetch_access_token(self, redirect_uri=None, request_token=None, **params): - """Fetch access token in one step. + async def fetch_access_token(self, redirect_uri=None, **kwargs): + """Fetch access token in the final step. :param redirect_uri: Callback or Redirect URI that is used in previous :meth:`authorize_redirect`. - :param request_token: A previous request token for OAuth 1. - :param params: Extra parameters to fetch access token. + :param kwargs: Extra parameters to fetch access token. :return: A token dict. """ metadata = await self.load_server_metadata() - token_endpoint = self.access_token_url - if not token_endpoint and not self.request_token_url: - token_endpoint = metadata.get('token_endpoint') - + token_endpoint = self.access_token_url or metadata.get("token_endpoint") async with self._get_oauth_client(**metadata) as client: - if self.request_token_url: - client.redirect_uri = redirect_uri - if request_token is None: - raise MissingRequestTokenError() - # merge request token with verifier - token = {} - token.update(request_token) - token.update(params) - client.token = token - kwargs = self.access_token_params or {} - token = await client.fetch_access_token(token_endpoint, **kwargs) - client.redirect_uri = None - else: + if redirect_uri is not None: client.redirect_uri = redirect_uri - kwargs = {} - if self.access_token_params: - kwargs.update(self.access_token_params) - kwargs.update(params) - token = await client.fetch_token(token_endpoint, **kwargs) - return token - - async def request(self, method, url, token=None, **kwargs): - if self.api_base_url and not url.startswith(('https://', 'http://')): - url = urlparse.urljoin(self.api_base_url, url) - - withhold_token = kwargs.get('withhold_token') - if not withhold_token: - metadata = await self.load_server_metadata() - else: - metadata = {} - - async with self._get_oauth_client(**metadata) as client: - request = kwargs.pop('request', None) - - if withhold_token: - return await client.request(method, url, **kwargs) - - if token is None and request: - token = await self._fetch_token(request) - - if token is None: - raise MissingTokenError() - - client.token = token - return await client.request(method, url, **kwargs) - - async def userinfo(self, **kwargs): - """Fetch user info from ``userinfo_endpoint``.""" - metadata = await self.load_server_metadata() - resp = await self.get(metadata['userinfo_endpoint'], **kwargs) - data = resp.json() - - compliance_fix = metadata.get('userinfo_compliance_fix') - if compliance_fix: - data = await compliance_fix(self, data) - return UserInfo(data) - - async def _parse_id_token(self, token, nonce, claims_options=None): - """Return an instance of UserInfo from token's ``id_token``.""" - claims_params = dict( - nonce=nonce, - client_id=self.client_id, - ) - if 'access_token' in token: - claims_params['access_token'] = token['access_token'] - claims_cls = CodeIDToken - else: - claims_cls = ImplicitIDToken - - metadata = await self.load_server_metadata() - if claims_options is None and 'issuer' in metadata: - claims_options = {'iss': {'values': [metadata['issuer']]}} - - alg_values = metadata.get('id_token_signing_alg_values_supported') - if not alg_values: - alg_values = ['RS256'] - - jwt = JsonWebToken(alg_values) - - jwk_set = await self._fetch_jwk_set() - try: - claims = jwt.decode( - token['id_token'], - key=JsonWebKey.import_key_set(jwk_set), - claims_cls=claims_cls, - claims_options=claims_options, - claims_params=claims_params, - ) - except ValueError: - jwk_set = await self._fetch_jwk_set(force=True) - claims = jwt.decode( - token['id_token'], - key=JsonWebKey.import_key_set(jwk_set), - claims_cls=claims_cls, - claims_options=claims_options, - claims_params=claims_params, - ) - - claims.validate(leeway=120) - return UserInfo(claims) - - async def _fetch_jwk_set(self, force=False): - metadata = await self.load_server_metadata() - jwk_set = metadata.get('jwks') - if jwk_set and not force: - return jwk_set - - uri = metadata.get('jwks_uri') - if not uri: - raise RuntimeError('Missing "jwks_uri" in metadata') - - jwk_set = await self._fetch_server_metadata(uri) - self.server_metadata['jwks'] = jwk_set - return jwk_set - - async def _fetch_server_metadata(self, url): - async with self._get_oauth_client() as client: - resp = await client.request('GET', url, withhold_token=True) - return resp.json() + params = {} + if self.access_token_params: + params.update(self.access_token_params) + params.update(kwargs) + token = await client.fetch_token(token_endpoint, **params) + return token + + +async def _http_request(ctx, session, method, url, token, kwargs): + request = kwargs.pop("request", None) + withhold_token = kwargs.get("withhold_token") + if ctx.api_base_url and not url.startswith(("https://", "http://")): + url = urlparse.urljoin(ctx.api_base_url, url) + + if withhold_token: + return await session.request(method, url, **kwargs) + + if token is None and ctx._fetch_token and request: + token = await ctx._fetch_token(request) + if token is None: + raise MissingTokenError() + + session.token = token + return await session.request(method, url, **kwargs) diff --git a/authlib/integrations/base_client/async_openid.py b/authlib/integrations/base_client/async_openid.py new file mode 100644 index 000000000..0a983a2a1 --- /dev/null +++ b/authlib/integrations/base_client/async_openid.py @@ -0,0 +1,123 @@ +from joserfc import jwt +from joserfc.errors import InvalidKeyIdError +from joserfc.jwk import KeySet + +from authlib.common.security import generate_token +from authlib.common.urls import add_params_to_uri +from authlib.oidc.core import CodeIDToken +from authlib.oidc.core import ImplicitIDToken +from authlib.oidc.core import UserInfo + +__all__ = ["AsyncOpenIDMixin"] + + +class AsyncOpenIDMixin: + async def fetch_jwk_set(self, force=False): + metadata = await self.load_server_metadata() + jwk_set = metadata.get("jwks") + if jwk_set and not force: + return jwk_set + + uri = metadata.get("jwks_uri") + if not uri: + raise RuntimeError('Missing "jwks_uri" in metadata') + + async with self._get_session() as client: + resp = await client.request("GET", uri, withhold_token=True) + resp.raise_for_status() + jwk_set = resp.json() + + self.server_metadata["jwks"] = jwk_set + return jwk_set + + async def userinfo(self, **kwargs): + """Fetch user info from ``userinfo_endpoint``.""" + metadata = await self.load_server_metadata() + resp = await self.get(metadata["userinfo_endpoint"], **kwargs) + resp.raise_for_status() + data = resp.json() + return UserInfo(data) + + async def parse_id_token( + self, token, nonce, claims_options=None, claims_cls=None, leeway=120 + ): + """Return an instance of UserInfo from token's ``id_token``.""" + claims_params = dict( + nonce=nonce, + client_id=self.client_id, + ) + if claims_cls is None: + if "access_token" in token: + claims_params["access_token"] = token["access_token"] + claims_cls = CodeIDToken + else: + claims_cls = ImplicitIDToken + + metadata = await self.load_server_metadata() + if claims_options is None and "issuer" in metadata: + claims_options = {"iss": {"values": [metadata["issuer"]]}} + + alg_values = metadata.get("id_token_signing_alg_values_supported") + if not alg_values: + alg_values = ["RS256"] + + jwks = await self.fetch_jwk_set() + key_set = KeySet.import_key_set(jwks) + try: + token = jwt.decode( + token["id_token"], + key=key_set, + algorithms=alg_values, + ) + except InvalidKeyIdError: + jwks = await self.fetch_jwk_set(force=True) + key_set = KeySet.import_key_set(jwks) + token = jwt.decode( + token["id_token"], + key=key_set, + algorithms=alg_values, + ) + + claims = claims_cls(token.claims, token.header, claims_options, claims_params) + # https://github.com/authlib/authlib/issues/259 + if claims.get("nonce_supported") is False: + claims.params["nonce"] = None + claims.validate(leeway=leeway) + return UserInfo(claims) + + async def create_logout_url( + self, + post_logout_redirect_uri=None, + id_token_hint=None, + state=None, + **kwargs, + ): + """Generate the end session URL for RP-Initiated Logout. + + :param post_logout_redirect_uri: URI to redirect after logout. + :param id_token_hint: ID Token previously issued to the RP. + :param state: Opaque value for maintaining state. + :param kwargs: Extra parameters (client_id, logout_hint, ui_locales). + :return: dict with 'url' and 'state' keys. + """ + metadata = await self.load_server_metadata() + end_session_endpoint = metadata.get("end_session_endpoint") + + if not end_session_endpoint: + raise RuntimeError('Missing "end_session_endpoint" in metadata') + + params = {} + if id_token_hint: + params["id_token_hint"] = id_token_hint + if post_logout_redirect_uri: + params["post_logout_redirect_uri"] = post_logout_redirect_uri + if state is None: + state = generate_token(20) + params["state"] = state + + for key in ("client_id", "logout_hint", "ui_locales"): + if key in kwargs: + params[key] = kwargs[key] + + url = add_params_to_uri(end_session_endpoint, params) + return {"url": url, "state": state} diff --git a/authlib/integrations/base_client/base_app.py b/authlib/integrations/base_client/base_app.py deleted file mode 100644 index 3df09a106..000000000 --- a/authlib/integrations/base_client/base_app.py +++ /dev/null @@ -1,244 +0,0 @@ -import logging - -from authlib.common.security import generate_token -from authlib.consts import default_user_agent -from .errors import ( - MismatchingStateError, -) - -__all__ = ['BaseApp'] - -log = logging.getLogger(__name__) - - -class BaseApp(object): - """The remote application for OAuth 1 and OAuth 2. It is used together - with OAuth registry. - - :param name: The name of the OAuth client, like: github, twitter - :param fetch_token: A function to fetch access token from database - :param update_token: A function to update access token to database - :param client_id: Client key of OAuth 1, or Client ID of OAuth 2 - :param client_secret: Client secret of OAuth 2, or Client Secret of OAuth 2 - :param request_token_url: Request Token endpoint for OAuth 1 - :param request_token_params: Extra parameters for Request Token endpoint - :param access_token_url: Access Token endpoint for OAuth 1 and OAuth 2 - :param access_token_params: Extra parameters for Access Token endpoint - :param authorize_url: Endpoint for user authorization of OAuth 1 or OAuth 2 - :param authorize_params: Extra parameters for Authorization Endpoint - :param api_base_url: The base API endpoint to make requests simple - :param client_kwargs: Extra keyword arguments for session - :param server_metadata_url: Discover server metadata from this URL - :param user_agent: Define a custom user agent to be used in HTTP request - :param kwargs: Extra server metadata - - Create an instance of ``RemoteApp``. If ``request_token_url`` is configured, - it would be an OAuth 1 instance, otherwise it is OAuth 2 instance:: - - oauth1_client = RemoteApp( - client_id='Twitter Consumer Key', - client_secret='Twitter Consumer Secret', - request_token_url='https://api.twitter.com/oauth/request_token', - access_token_url='https://api.twitter.com/oauth/access_token', - authorize_url='https://api.twitter.com/oauth/authenticate', - api_base_url='https://api.twitter.com/1.1/', - ) - - oauth2_client = RemoteApp( - client_id='GitHub Client ID', - client_secret='GitHub Client Secret', - api_base_url='https://api.github.com/', - access_token_url='https://github.com/login/oauth/access_token', - authorize_url='https://github.com/login/oauth/authorize', - client_kwargs={'scope': 'user:email'}, - ) - """ - OAUTH_APP_CONFIG = None - - def __init__( - self, framework, name=None, fetch_token=None, update_token=None, - client_id=None, client_secret=None, - request_token_url=None, request_token_params=None, - access_token_url=None, access_token_params=None, - authorize_url=None, authorize_params=None, - api_base_url=None, client_kwargs=None, server_metadata_url=None, - compliance_fix=None, client_auth_methods=None, user_agent=None, **kwargs): - - self.framework = framework - self.name = name - self.client_id = client_id - self.client_secret = client_secret - self.request_token_url = request_token_url - self.request_token_params = request_token_params - self.access_token_url = access_token_url - self.access_token_params = access_token_params - self.authorize_url = authorize_url - self.authorize_params = authorize_params - self.api_base_url = api_base_url - self.client_kwargs = client_kwargs or {} - - self.compliance_fix = compliance_fix - self.client_auth_methods = client_auth_methods - self._fetch_token = fetch_token - self._update_token = update_token - self._user_agent = user_agent or default_user_agent - - self._server_metadata_url = server_metadata_url - self.server_metadata = kwargs - - def _on_update_token(self, token, refresh_token=None, access_token=None): - raise NotImplementedError() - - def _get_oauth_client(self, **kwargs): - client_kwargs = {} - client_kwargs.update(self.client_kwargs) - client_kwargs.update(kwargs) - if self.request_token_url: - session = self.framework.oauth1_client_cls( - self.client_id, self.client_secret, - **client_kwargs - ) - else: - if self.authorize_url: - client_kwargs['authorization_endpoint'] = self.authorize_url - if self.access_token_url: - client_kwargs['token_endpoint'] = self.access_token_url - session = self.framework.oauth2_client_cls( - client_id=self.client_id, - client_secret=self.client_secret, - update_token=self._on_update_token, - **client_kwargs - ) - if self.client_auth_methods: - for f in self.client_auth_methods: - session.register_client_auth_method(f) - # only OAuth2 has compliance_fix currently - if self.compliance_fix: - self.compliance_fix(session) - - session.headers['User-Agent'] = self._user_agent - return session - - def _retrieve_oauth2_access_token_params(self, request, params): - request_state = params.pop('state', None) - state = self.framework.get_session_data(request, 'state') - if state != request_state: - raise MismatchingStateError() - if state: - params['state'] = state - - code_verifier = self.framework.get_session_data(request, 'code_verifier') - if code_verifier: - params['code_verifier'] = code_verifier - return params - - def retrieve_access_token_params(self, request, request_token=None): - """Retrieve parameters for fetching access token, those parameters come - from request and previously saved temporary data in session. - """ - params = self.framework.generate_access_token_params(self.request_token_url, request) - if self.request_token_url: - if request_token is None: - request_token = self.framework.get_session_data(request, 'request_token') - params['request_token'] = request_token - else: - params = self._retrieve_oauth2_access_token_params(request, params) - - redirect_uri = self.framework.get_session_data(request, 'redirect_uri') - if redirect_uri: - params['redirect_uri'] = redirect_uri - - log.debug('Retrieve temporary data: {!r}'.format(params)) - return params - - def save_authorize_data(self, request, **kwargs): - """Save temporary data into session for the authorization step. These - data can be retrieved later when fetching access token. - """ - log.debug('Saving authorize data: {!r}'.format(kwargs)) - keys = [ - 'redirect_uri', 'request_token', - 'state', 'code_verifier', 'nonce' - ] - for k in keys: - if k in kwargs: - self.framework.set_session_data(request, k, kwargs[k]) - - @staticmethod - def _create_oauth2_authorization_url(client, authorization_endpoint, **kwargs): - rv = {} - if client.code_challenge_method: - code_verifier = kwargs.get('code_verifier') - if not code_verifier: - code_verifier = generate_token(48) - kwargs['code_verifier'] = code_verifier - rv['code_verifier'] = code_verifier - log.debug('Using code_verifier: {!r}'.format(code_verifier)) - - scope = kwargs.get('scope', client.scope) - if scope and scope.startswith('openid'): - # this is an OpenID Connect service - nonce = kwargs.get('nonce') - if not nonce: - nonce = generate_token(20) - kwargs['nonce'] = nonce - rv['nonce'] = nonce - - url, state = client.create_authorization_url( - authorization_endpoint, **kwargs) - rv['url'] = url - rv['state'] = state - return rv - - def request(self, method, url, token=None, **kwargs): - raise NotImplementedError() - - def get(self, url, **kwargs): - """Invoke GET http request. - - If ``api_base_url`` configured, shortcut is available:: - - client.get('users/lepture') - """ - return self.request('GET', url, **kwargs) - - def post(self, url, **kwargs): - """Invoke POST http request. - - If ``api_base_url`` configured, shortcut is available:: - - client.post('timeline', json={'text': 'Hi'}) - """ - return self.request('POST', url, **kwargs) - - def patch(self, url, **kwargs): - """Invoke PATCH http request. - - If ``api_base_url`` configured, shortcut is available:: - - client.patch('profile', json={'name': 'Hsiaoming Yang'}) - """ - return self.request('PATCH', url, **kwargs) - - def put(self, url, **kwargs): - """Invoke PUT http request. - - If ``api_base_url`` configured, shortcut is available:: - - client.put('profile', json={'name': 'Hsiaoming Yang'}) - """ - return self.request('PUT', url, **kwargs) - - def delete(self, url, **kwargs): - """Invoke DELETE http request. - - If ``api_base_url`` configured, shortcut is available:: - - client.delete('posts/123') - """ - return self.request('DELETE', url, **kwargs) - - def _fetch_server_metadata(self, url): - with self._get_oauth_client() as session: - resp = session.request('GET', url, withhold_token=True) - return resp.json() diff --git a/authlib/integrations/base_client/errors.py b/authlib/integrations/base_client/errors.py index bb4dd2b12..4d5078c28 100644 --- a/authlib/integrations/base_client/errors.py +++ b/authlib/integrations/base_client/errors.py @@ -2,29 +2,29 @@ class OAuthError(AuthlibBaseError): - error = 'oauth_error' + error = "oauth_error" class MissingRequestTokenError(OAuthError): - error = 'missing_request_token' + error = "missing_request_token" class MissingTokenError(OAuthError): - error = 'missing_token' + error = "missing_token" class TokenExpiredError(OAuthError): - error = 'token_expired' + error = "token_expired" class InvalidTokenError(OAuthError): - error = 'token_invalid' + error = "token_invalid" class UnsupportedTokenTypeError(OAuthError): - error = 'unsupported_token_type' + error = "unsupported_token_type" class MismatchingStateError(OAuthError): - error = 'mismatching_state' - description = 'CSRF Warning! State not equal in request and response.' + error = "mismatching_state" + description = "CSRF Warning! State not equal in request and response." diff --git a/authlib/integrations/base_client/framework_integration.py b/authlib/integrations/base_client/framework_integration.py index 2f27689c2..3ca43c02e 100644 --- a/authlib/integrations/base_client/framework_integration.py +++ b/authlib/integrations/base_client/framework_integration.py @@ -1,23 +1,63 @@ +import json +import time -class FrameworkIntegration(object): - oauth1_client_cls = None - oauth2_client_cls = None - def __init__(self, name): +class FrameworkIntegration: + expires_in = 3600 + + def __init__(self, name, cache=None): self.name = name + self.cache = cache - def set_session_data(self, request, key, value): - sess_key = '_{}_authlib_{}_'.format(self.name, key) - request.session[sess_key] = value + def _get_cache_data(self, key): + value = self.cache.get(key) + if not value: + return None + try: + return json.loads(value) + except (TypeError, ValueError): + return None - def get_session_data(self, request, key): - sess_key = '_{}_authlib_{}_'.format(self.name, key) - return request.session.pop(sess_key, None) + def _clear_session_state(self, session): + now = time.time() + prefix = f"_state_{self.name}" + for key in dict(session): + if key.startswith(prefix): + value = session[key] + exp = value.get("exp") + if not exp or exp < now: + session.pop(key) - def update_token(self, token, refresh_token=None, access_token=None): - raise NotImplementedError() + def get_state_data(self, session, state): + key = f"_state_{self.name}_{state}" + session_data = session.get(key) + if not session_data: + return None + if self.cache: + cached_value = self._get_cache_data(key) + else: + cached_value = session_data + if cached_value: + return cached_value.get("data") + return None - def generate_access_token_params(self, request_token_url, request): + def set_state_data(self, session, state, data): + key = f"_state_{self.name}_{state}" + now = time.time() + if self.cache: + self.cache.set(key, json.dumps({"data": data}), self.expires_in) + session[key] = {"exp": now + self.expires_in} + else: + session[key] = {"data": data, "exp": now + self.expires_in} + + def clear_state_data(self, session, state): + key = f"_state_{self.name}_{state}" + if self.cache: + self.cache.delete(key) + session.pop(key, None) + self._clear_session_state(session) + + def update_token(self, token, refresh_token=None, access_token=None): raise NotImplementedError() @staticmethod diff --git a/authlib/integrations/base_client/base_oauth.py b/authlib/integrations/base_client/registry.py similarity index 67% rename from authlib/integrations/base_client/base_oauth.py rename to authlib/integrations/base_client/registry.py index 36c027b01..407448289 100644 --- a/authlib/integrations/base_client/base_oauth.py +++ b/authlib/integrations/base_client/registry.py @@ -1,33 +1,43 @@ import functools + from .framework_integration import FrameworkIntegration -__all__ = ['BaseOAuth'] +__all__ = ["BaseOAuth"] OAUTH_CLIENT_PARAMS = ( - 'client_id', 'client_secret', - 'request_token_url', 'request_token_params', - 'access_token_url', 'access_token_params', - 'refresh_token_url', 'refresh_token_params', - 'authorize_url', 'authorize_params', - 'api_base_url', 'client_kwargs', - 'server_metadata_url', + "client_id", + "client_secret", + "request_token_url", + "request_token_params", + "access_token_url", + "access_token_params", + "refresh_token_url", + "refresh_token_params", + "authorize_url", + "authorize_params", + "api_base_url", + "client_kwargs", + "server_metadata_url", ) -class BaseOAuth(object): +class BaseOAuth: """Registry for oauth clients. Create an instance for registry:: oauth = OAuth() """ - framework_client_cls = None + + oauth1_client_cls = None + oauth2_client_cls = None framework_integration_cls = FrameworkIntegration - def __init__(self, fetch_token=None, update_token=None): + def __init__(self, cache=None, fetch_token=None, update_token=None): self._registry = {} self._clients = {} + self.cache = cache self.fetch_token = fetch_token self.update_token = update_token @@ -36,7 +46,7 @@ def create_client(self, name): OAuth registry has ``.register`` a twitter client, developers may access the client with:: - client = oauth.create_client('twitter') + client = oauth.create_client("twitter") :param: name: Name of the remote application :return: OAuth remote app @@ -48,14 +58,23 @@ def create_client(self, name): return None overwrite, config = self._registry[name] - client_cls = config.pop('client_cls', self.framework_client_cls) - if client_cls.OAUTH_APP_CONFIG: + client_cls = config.pop("client_cls", None) + + if client_cls and client_cls.OAUTH_APP_CONFIG: kwargs = client_cls.OAUTH_APP_CONFIG kwargs.update(config) else: kwargs = config + kwargs = self.generate_client_kwargs(name, overwrite, **kwargs) - client = client_cls(self.framework_integration_cls(name), name, **kwargs) + framework = self.framework_integration_cls(name, self.cache) + if client_cls: + client = client_cls(framework, name, **kwargs) + elif kwargs.get("request_token_url"): + client = self.oauth1_client_cls(framework, name, **kwargs) + else: + client = self.oauth2_client_cls(framework, name, **kwargs) + self._clients[name] = client return client @@ -76,8 +95,8 @@ def register(self, name, overwrite=False, **kwargs): return self.create_client(name) def generate_client_kwargs(self, name, overwrite, **kwargs): - fetch_token = kwargs.pop('fetch_token', None) - update_token = kwargs.pop('update_token', None) + fetch_token = kwargs.pop("fetch_token", None) + update_token = kwargs.pop("update_token", None) config = self.load_config(name, OAUTH_CLIENT_PARAMS) if config: @@ -86,13 +105,13 @@ def generate_client_kwargs(self, name, overwrite, **kwargs): if not fetch_token and self.fetch_token: fetch_token = functools.partial(self.fetch_token, name) - kwargs['fetch_token'] = fetch_token + kwargs["fetch_token"] = fetch_token - if not kwargs.get('request_token_url'): + if not kwargs.get("request_token_url"): if not update_token and self.update_token: update_token = functools.partial(self.update_token, name) - kwargs['update_token'] = update_token + kwargs["update_token"] = update_token return kwargs def load_config(self, name, params): @@ -101,10 +120,10 @@ def load_config(self, name, params): def __getattr__(self, key): try: return object.__getattribute__(self, key) - except AttributeError: + except AttributeError as exc: if key in self._registry: return self.create_client(key) - raise AttributeError('No such client: %s' % key) + raise AttributeError(f"No such client: {key}") from exc def _config_client(config, kwargs, overwrite): diff --git a/authlib/integrations/base_client/remote_app.py b/authlib/integrations/base_client/remote_app.py deleted file mode 100644 index 430d56f3a..000000000 --- a/authlib/integrations/base_client/remote_app.py +++ /dev/null @@ -1,204 +0,0 @@ -import time -import logging -from authlib.common.urls import urlparse -from authlib.jose import JsonWebToken, JsonWebKey -from authlib.oidc.core import UserInfo, CodeIDToken, ImplicitIDToken -from .base_app import BaseApp -from .errors import ( - MissingRequestTokenError, - MissingTokenError, -) - -__all__ = ['RemoteApp'] - -log = logging.getLogger(__name__) - - -class RemoteApp(BaseApp): - def load_server_metadata(self): - if self._server_metadata_url and '_loaded_at' not in self.server_metadata: - metadata = self._fetch_server_metadata(self._server_metadata_url) - metadata['_loaded_at'] = time.time() - self.server_metadata.update(metadata) - return self.server_metadata - - def _on_update_token(self, token, refresh_token=None, access_token=None): - if callable(self._update_token): - self._update_token( - token, - refresh_token=refresh_token, - access_token=access_token, - ) - self.framework.update_token( - token, - refresh_token=refresh_token, - access_token=access_token, - ) - - def _create_oauth1_authorization_url(self, client, authorization_endpoint, **kwargs): - params = {} - if self.request_token_params: - params.update(self.request_token_params) - token = client.fetch_request_token( - self.request_token_url, **params - ) - log.debug('Fetch request token: {!r}'.format(token)) - url = client.create_authorization_url(authorization_endpoint, **kwargs) - return {'url': url, 'request_token': token} - - def create_authorization_url(self, redirect_uri=None, **kwargs): - """Generate the authorization url and state for HTTP redirect. - - :param redirect_uri: Callback or redirect URI for authorization. - :param kwargs: Extra parameters to include. - :return: dict - """ - metadata = self.load_server_metadata() - authorization_endpoint = self.authorize_url - if not authorization_endpoint and not self.request_token_url: - authorization_endpoint = metadata.get('authorization_endpoint') - - if not authorization_endpoint: - raise RuntimeError('Missing "authorize_url" value') - - if self.authorize_params: - kwargs.update(self.authorize_params) - - with self._get_oauth_client(**metadata) as client: - client.redirect_uri = redirect_uri - - if self.request_token_url: - return self._create_oauth1_authorization_url( - client, authorization_endpoint, **kwargs) - else: - return self._create_oauth2_authorization_url( - client, authorization_endpoint, **kwargs) - - def fetch_access_token(self, redirect_uri=None, request_token=None, **params): - """Fetch access token in one step. - - :param redirect_uri: Callback or Redirect URI that is used in - previous :meth:`authorize_redirect`. - :param request_token: A previous request token for OAuth 1. - :param params: Extra parameters to fetch access token. - :return: A token dict. - """ - metadata = self.load_server_metadata() - token_endpoint = self.access_token_url - if not token_endpoint and not self.request_token_url: - token_endpoint = metadata.get('token_endpoint') - - with self._get_oauth_client(**metadata) as client: - if self.request_token_url: - client.redirect_uri = redirect_uri - if request_token is None: - raise MissingRequestTokenError() - # merge request token with verifier - token = {} - token.update(request_token) - token.update(params) - client.token = token - kwargs = self.access_token_params or {} - token = client.fetch_access_token(token_endpoint, **kwargs) - client.redirect_uri = None - else: - client.redirect_uri = redirect_uri - kwargs = {} - if self.access_token_params: - kwargs.update(self.access_token_params) - kwargs.update(params) - token = client.fetch_token(token_endpoint, **kwargs) - return token - - def request(self, method, url, token=None, **kwargs): - if self.api_base_url and not url.startswith(('https://', 'http://')): - url = urlparse.urljoin(self.api_base_url, url) - - withhold_token = kwargs.get('withhold_token') - if not withhold_token: - metadata = self.load_server_metadata() - else: - metadata = {} - - with self._get_oauth_client(**metadata) as session: - request = kwargs.pop('request', None) - if withhold_token: - return session.request(method, url, **kwargs) - - if token is None and self._fetch_token and request: - token = self._fetch_token(request) - if token is None: - raise MissingTokenError() - - session.token = token - return session.request(method, url, **kwargs) - - def fetch_jwk_set(self, force=False): - metadata = self.load_server_metadata() - jwk_set = metadata.get('jwks') - if jwk_set and not force: - return jwk_set - uri = metadata.get('jwks_uri') - if not uri: - raise RuntimeError('Missing "jwks_uri" in metadata') - - jwk_set = self._fetch_server_metadata(uri) - self.server_metadata['jwks'] = jwk_set - return jwk_set - - def userinfo(self, **kwargs): - """Fetch user info from ``userinfo_endpoint``.""" - metadata = self.load_server_metadata() - resp = self.get(metadata['userinfo_endpoint'], **kwargs) - data = resp.json() - - compliance_fix = metadata.get('userinfo_compliance_fix') - if compliance_fix: - data = compliance_fix(self, data) - return UserInfo(data) - - def _parse_id_token(self, request, token, claims_options=None, leeway=120): - """Return an instance of UserInfo from token's ``id_token``.""" - if 'id_token' not in token: - return None - - def load_key(header, payload): - jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set()) - try: - return jwk_set.find_by_kid(header.get('kid')) - except ValueError: - # re-try with new jwk set - jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=True)) - return jwk_set.find_by_kid(header.get('kid')) - - nonce = self.framework.get_session_data(request, 'nonce') - claims_params = dict( - nonce=nonce, - client_id=self.client_id, - ) - if 'access_token' in token: - claims_params['access_token'] = token['access_token'] - claims_cls = CodeIDToken - else: - claims_cls = ImplicitIDToken - - metadata = self.load_server_metadata() - if claims_options is None and 'issuer' in metadata: - claims_options = {'iss': {'values': [metadata['issuer']]}} - - alg_values = metadata.get('id_token_signing_alg_values_supported') - if not alg_values: - alg_values = ['RS256'] - - jwt = JsonWebToken(alg_values) - claims = jwt.decode( - token['id_token'], key=load_key, - claims_cls=claims_cls, - claims_options=claims_options, - claims_params=claims_params, - ) - # https://github.com/lepture/authlib/issues/259 - if claims.get('nonce_supported') is False: - claims.params['nonce'] = None - claims.validate(leeway=leeway) - return UserInfo(claims) diff --git a/authlib/integrations/base_client/sync_app.py b/authlib/integrations/base_client/sync_app.py new file mode 100644 index 000000000..3c8f32494 --- /dev/null +++ b/authlib/integrations/base_client/sync_app.py @@ -0,0 +1,382 @@ +import logging +import time + +from authlib.common.security import generate_token +from authlib.common.urls import urlparse +from authlib.consts import default_user_agent + +from .errors import MismatchingStateError +from .errors import MissingRequestTokenError +from .errors import MissingTokenError + +log = logging.getLogger(__name__) + + +class BaseApp: + client_cls = None + OAUTH_APP_CONFIG = None + + def request(self, method, url, token=None, **kwargs): + raise NotImplementedError() + + def get(self, url, **kwargs): + """Invoke GET http request. + + If ``api_base_url`` configured, shortcut is available:: + + client.get("users/lepture") + """ + return self.request("GET", url, **kwargs) + + def post(self, url, **kwargs): + """Invoke POST http request. + + If ``api_base_url`` configured, shortcut is available:: + + client.post("timeline", json={"text": "Hi"}) + """ + return self.request("POST", url, **kwargs) + + def patch(self, url, **kwargs): + """Invoke PATCH http request. + + If ``api_base_url`` configured, shortcut is available:: + + client.patch("profile", json={"name": "Hsiaoming Yang"}) + """ + return self.request("PATCH", url, **kwargs) + + def put(self, url, **kwargs): + """Invoke PUT http request. + + If ``api_base_url`` configured, shortcut is available:: + + client.put("profile", json={"name": "Hsiaoming Yang"}) + """ + return self.request("PUT", url, **kwargs) + + def delete(self, url, **kwargs): + """Invoke DELETE http request. + + If ``api_base_url`` configured, shortcut is available:: + + client.delete("posts/123") + """ + return self.request("DELETE", url, **kwargs) + + +class _RequestMixin: + def _get_requested_token(self, request): + if self._fetch_token and request: + return self._fetch_token(request) + + def _send_token_request(self, session, method, url, token, kwargs): + request = kwargs.pop("request", None) + withhold_token = kwargs.get("withhold_token") + if self.api_base_url and not url.startswith(("https://", "http://")): + url = urlparse.urljoin(self.api_base_url, url) + + if withhold_token: + return session.request(method, url, **kwargs) + + if token is None: + token = self._get_requested_token(request) + + if token is None: + raise MissingTokenError() + + session.token = token + return session.request(method, url, **kwargs) + + +class OAuth1Base: + client_cls = None + + def __init__( + self, + framework, + name=None, + fetch_token=None, + client_id=None, + client_secret=None, + request_token_url=None, + request_token_params=None, + access_token_url=None, + access_token_params=None, + authorize_url=None, + authorize_params=None, + api_base_url=None, + client_kwargs=None, + user_agent=None, + **kwargs, + ): + self.framework = framework + self.name = name + self.client_id = client_id + self.client_secret = client_secret + self.request_token_url = request_token_url + self.request_token_params = request_token_params + self.access_token_url = access_token_url + self.access_token_params = access_token_params + self.authorize_url = authorize_url + self.authorize_params = authorize_params + self.api_base_url = api_base_url + self.client_kwargs = client_kwargs or {} + + self._fetch_token = fetch_token + self._user_agent = user_agent or default_user_agent + self._kwargs = kwargs + + def _get_oauth_client(self): + session = self.client_cls( + self.client_id, self.client_secret, **self.client_kwargs + ) + session.headers["User-Agent"] = self._user_agent + return session + + +class OAuth1Mixin(_RequestMixin, OAuth1Base): + def request(self, method, url, token=None, **kwargs): + with self._get_oauth_client() as session: + return self._send_token_request(session, method, url, token, kwargs) + + def create_authorization_url(self, redirect_uri=None, **kwargs): + """Generate the authorization url and state for HTTP redirect. + + :param redirect_uri: Callback or redirect URI for authorization. + :param kwargs: Extra parameters to include. + :return: dict + """ + if not self.authorize_url: + raise RuntimeError('Missing "authorize_url" value') + + if self.authorize_params: + kwargs.update(self.authorize_params) + + with self._get_oauth_client() as client: + client.redirect_uri = redirect_uri + params = self.request_token_params or {} + request_token = client.fetch_request_token(self.request_token_url, **params) + log.debug(f"Fetch request token: {request_token!r}") + url = client.create_authorization_url(self.authorize_url, **kwargs) + state = request_token["oauth_token"] + return {"url": url, "request_token": request_token, "state": state} + + def fetch_access_token(self, request_token=None, **kwargs): + """Fetch access token in one step. + + :param request_token: A previous request token for OAuth 1. + :param kwargs: Extra parameters to fetch access token. + :return: A token dict. + """ + with self._get_oauth_client() as client: + if request_token is None: + raise MissingRequestTokenError() + # merge request token with verifier + token = {} + token.update(request_token) + token.update(kwargs) + client.token = token + params = self.access_token_params or {} + token = client.fetch_access_token(self.access_token_url, **params) + return token + + +class OAuth2Base: + client_cls = None + + def __init__( + self, + framework, + name=None, + fetch_token=None, + update_token=None, + client_id=None, + client_secret=None, + access_token_url=None, + access_token_params=None, + authorize_url=None, + authorize_params=None, + api_base_url=None, + client_kwargs=None, + server_metadata_url=None, + compliance_fix=None, + client_auth_methods=None, + user_agent=None, + **kwargs, + ): + self.framework = framework + self.name = name + self.client_id = client_id + self.client_secret = client_secret + self.access_token_url = access_token_url + self.access_token_params = access_token_params + self.authorize_url = authorize_url + self.authorize_params = authorize_params + self.api_base_url = api_base_url + self.client_kwargs = client_kwargs or {} + + self.compliance_fix = compliance_fix + self.client_auth_methods = client_auth_methods + self._fetch_token = fetch_token + self._update_token = update_token + self._user_agent = user_agent or default_user_agent + + self._server_metadata_url = server_metadata_url + self.server_metadata = kwargs + + def _on_update_token(self, token, refresh_token=None, access_token=None): + raise NotImplementedError() + + def _get_session(self): + session = self.client_cls(**self.client_kwargs) + session.headers["User-Agent"] = self._user_agent + return session + + def _get_oauth_client(self, **metadata): + client_kwargs = {} + client_kwargs.update(self.client_kwargs) + client_kwargs.update(metadata) + + if self.authorize_url: + client_kwargs["authorization_endpoint"] = self.authorize_url + if self.access_token_url: + client_kwargs["token_endpoint"] = self.access_token_url + + session = self.client_cls( + client_id=self.client_id, + client_secret=self.client_secret, + update_token=self._on_update_token, + **client_kwargs, + ) + if self.client_auth_methods: + for f in self.client_auth_methods: + session.register_client_auth_method(f) + + if self.compliance_fix: + self.compliance_fix(session) + + session.headers["User-Agent"] = self._user_agent + return session + + @staticmethod + def _format_state_params(state_data, params): + if state_data is None: + raise MismatchingStateError() + + code_verifier = state_data.get("code_verifier") + if code_verifier: + params["code_verifier"] = code_verifier + + redirect_uri = state_data.get("redirect_uri") + if redirect_uri: + params["redirect_uri"] = redirect_uri + return params + + @staticmethod + def _create_oauth2_authorization_url(client, authorization_endpoint, **kwargs): + rv = {} + if client.code_challenge_method: + code_verifier = kwargs.get("code_verifier") + if not code_verifier: + code_verifier = generate_token(48) + kwargs["code_verifier"] = code_verifier + rv["code_verifier"] = code_verifier + log.debug(f"Using code_verifier: {code_verifier!r}") + + scope = kwargs.get("scope", client.scope) + scope = ( + (scope if isinstance(scope, (list, tuple)) else scope.split()) + if scope + else None + ) + if scope and "openid" in scope: + # this is an OpenID Connect service + nonce = kwargs.get("nonce") + if not nonce: + nonce = generate_token(20) + kwargs["nonce"] = nonce + rv["nonce"] = nonce + + url, state = client.create_authorization_url(authorization_endpoint, **kwargs) + rv["url"] = url + rv["state"] = state + return rv + + +class OAuth2Mixin(_RequestMixin, OAuth2Base): + def _on_update_token(self, token, refresh_token=None, access_token=None): + if callable(self._update_token): + self._update_token( + token, + refresh_token=refresh_token, + access_token=access_token, + ) + self.framework.update_token( + token, + refresh_token=refresh_token, + access_token=access_token, + ) + + def request(self, method, url, token=None, **kwargs): + metadata = self.load_server_metadata() + with self._get_oauth_client(**metadata) as session: + return self._send_token_request(session, method, url, token, kwargs) + + def load_server_metadata(self): + if self._server_metadata_url and "_loaded_at" not in self.server_metadata: + with self._get_session() as session: + resp = session.request( + "GET", self._server_metadata_url, withhold_token=True + ) + resp.raise_for_status() + metadata = resp.json() + + metadata["_loaded_at"] = time.time() + self.server_metadata.update(metadata) + return self.server_metadata + + def create_authorization_url(self, redirect_uri=None, **kwargs): + """Generate the authorization url and state for HTTP redirect. + + :param redirect_uri: Callback or redirect URI for authorization. + :param kwargs: Extra parameters to include. + :return: dict + """ + metadata = self.load_server_metadata() + authorization_endpoint = self.authorize_url or metadata.get( + "authorization_endpoint" + ) + + if not authorization_endpoint: + raise RuntimeError('Missing "authorize_url" value') + + if self.authorize_params: + kwargs.update(self.authorize_params) + + with self._get_oauth_client(**metadata) as client: + if redirect_uri is not None: + client.redirect_uri = redirect_uri + return self._create_oauth2_authorization_url( + client, authorization_endpoint, **kwargs + ) + + def fetch_access_token(self, redirect_uri=None, **kwargs): + """Fetch access token in the final step. + + :param redirect_uri: Callback or Redirect URI that is used in + previous :meth:`authorize_redirect`. + :param kwargs: Extra parameters to fetch access token. + :return: A token dict. + """ + metadata = self.load_server_metadata() + token_endpoint = self.access_token_url or metadata.get("token_endpoint") + with self._get_oauth_client(**metadata) as client: + if redirect_uri is not None: + client.redirect_uri = redirect_uri + params = {} + if self.access_token_params: + params.update(self.access_token_params) + params.update(kwargs) + token = client.fetch_token(token_endpoint, **params) + return token diff --git a/authlib/integrations/base_client/sync_openid.py b/authlib/integrations/base_client/sync_openid.py new file mode 100644 index 000000000..42e5f8271 --- /dev/null +++ b/authlib/integrations/base_client/sync_openid.py @@ -0,0 +1,122 @@ +from joserfc import jwt +from joserfc.errors import InvalidKeyIdError +from joserfc.jwk import KeySet + +from authlib.common.security import generate_token +from authlib.common.urls import add_params_to_uri +from authlib.oidc.core import CodeIDToken +from authlib.oidc.core import ImplicitIDToken +from authlib.oidc.core import UserInfo + + +class OpenIDMixin: + def fetch_jwk_set(self, force=False): + metadata = self.load_server_metadata() + jwk_set = metadata.get("jwks") + if jwk_set and not force: + return jwk_set + + uri = metadata.get("jwks_uri") + if not uri: + raise RuntimeError('Missing "jwks_uri" in metadata') + + with self._get_session() as session: + resp = session.request("GET", uri, withhold_token=True) + resp.raise_for_status() + jwk_set = resp.json() + + self.server_metadata["jwks"] = jwk_set + return jwk_set + + def userinfo(self, **kwargs): + """Fetch user info from ``userinfo_endpoint``.""" + metadata = self.load_server_metadata() + resp = self.get(metadata["userinfo_endpoint"], **kwargs) + resp.raise_for_status() + data = resp.json() + return UserInfo(data) + + def parse_id_token( + self, token, nonce, claims_options=None, claims_cls=None, leeway=120 + ): + """Return an instance of UserInfo from token's ``id_token``.""" + if "id_token" not in token: + return None + + claims_params = dict( + nonce=nonce, + client_id=self.client_id, + ) + + if claims_cls is None: + if "access_token" in token: + claims_params["access_token"] = token["access_token"] + claims_cls = CodeIDToken + else: + claims_cls = ImplicitIDToken + + metadata = self.load_server_metadata() + if claims_options is None and "issuer" in metadata: + claims_options = {"iss": {"values": [metadata["issuer"]]}} + + alg_values = metadata.get("id_token_signing_alg_values_supported") + + key_set = KeySet.import_key_set(self.fetch_jwk_set()) + try: + token = jwt.decode( + token["id_token"], + key=key_set, + algorithms=alg_values, + ) + except InvalidKeyIdError: + key_set = KeySet.import_key_set(self.fetch_jwk_set(force=True)) + token = jwt.decode( + token["id_token"], + key=key_set, + algorithms=alg_values, + ) + + claims = claims_cls(token.claims, token.header, claims_options, claims_params) + # https://github.com/authlib/authlib/issues/259 + if claims.get("nonce_supported") is False: + claims.params["nonce"] = None + + claims.validate(leeway=leeway) + return UserInfo(claims) + + def create_logout_url( + self, + post_logout_redirect_uri=None, + id_token_hint=None, + state=None, + **kwargs, + ): + """Generate the end session URL for RP-Initiated Logout. + + :param post_logout_redirect_uri: URI to redirect after logout. + :param id_token_hint: ID Token previously issued to the RP. + :param state: Opaque value for maintaining state. + :param kwargs: Extra parameters (client_id, logout_hint, ui_locales). + :return: dict with 'url' and 'state' keys. + """ + metadata = self.load_server_metadata() + end_session_endpoint = metadata.get("end_session_endpoint") + + if not end_session_endpoint: + raise RuntimeError('Missing "end_session_endpoint" in metadata') + + params = {} + if id_token_hint: + params["id_token_hint"] = id_token_hint + if post_logout_redirect_uri: + params["post_logout_redirect_uri"] = post_logout_redirect_uri + if state is None: + state = generate_token(20) + params["state"] = state + + for key in ("client_id", "logout_hint", "ui_locales"): + if key in kwargs: + params[key] = kwargs[key] + + url = add_params_to_uri(end_session_endpoint, params) + return {"url": url, "state": state} diff --git a/authlib/integrations/django_client/__init__.py b/authlib/integrations/django_client/__init__.py index 18a30ca4a..28b5ff071 100644 --- a/authlib/integrations/django_client/__init__.py +++ b/authlib/integrations/django_client/__init__.py @@ -1,15 +1,22 @@ -# flake8: noqa - -from .integration import DjangoIntegration, DjangoRemoteApp, token_update -from ..base_client import BaseOAuth, OAuthError +from ..base_client import BaseOAuth +from ..base_client import OAuthError +from .apps import DjangoOAuth1App +from .apps import DjangoOAuth2App +from .integration import DjangoIntegration +from .integration import token_update class OAuth(BaseOAuth): + oauth1_client_cls = DjangoOAuth1App + oauth2_client_cls = DjangoOAuth2App framework_integration_cls = DjangoIntegration - framework_client_cls = DjangoRemoteApp __all__ = [ - 'OAuth', 'DjangoRemoteApp', 'DjangoIntegration', - 'token_update', 'OAuthError', + "OAuth", + "DjangoOAuth1App", + "DjangoOAuth2App", + "DjangoIntegration", + "token_update", + "OAuthError", ] diff --git a/authlib/integrations/django_client/apps.py b/authlib/integrations/django_client/apps.py new file mode 100644 index 000000000..632cd8775 --- /dev/null +++ b/authlib/integrations/django_client/apps.py @@ -0,0 +1,143 @@ +from django.http import HttpResponseRedirect + +from ..base_client import BaseApp +from ..base_client import OAuth1Mixin +from ..base_client import OAuth2Mixin +from ..base_client import OAuthError +from ..base_client import OpenIDMixin +from ..requests_client import OAuth1Session +from ..requests_client import OAuth2Session + + +class DjangoAppMixin: + def save_authorize_data(self, request, **kwargs): + state = kwargs.pop("state", None) + if state: + self.framework.set_state_data(request.session, state, kwargs) + else: + raise RuntimeError("Missing state value") + + def authorize_redirect(self, request, redirect_uri=None, **kwargs): + """Create a HTTP Redirect for Authorization Endpoint. + + :param request: HTTP request instance from Django view. + :param redirect_uri: Callback or redirect URI for authorization. + :param kwargs: Extra parameters to include. + :return: A HTTP redirect response. + """ + rv = self.create_authorization_url(redirect_uri, **kwargs) + self.save_authorize_data(request, redirect_uri=redirect_uri, **rv) + return HttpResponseRedirect(rv["url"]) + + +class DjangoOAuth1App(DjangoAppMixin, OAuth1Mixin, BaseApp): + client_cls = OAuth1Session + + def authorize_access_token(self, request, **kwargs): + """Fetch access token in one step. + + :param request: HTTP request instance from Django view. + :return: A token dict. + """ + params = request.GET.dict() + state = params.get("oauth_token") + if not state: + raise OAuthError(description='Missing "oauth_token" parameter') + + data = self.framework.get_state_data(request.session, state) + if not data: + raise OAuthError(description='Missing "request_token" in temporary data') + + params["request_token"] = data["request_token"] + params.update(kwargs) + self.framework.clear_state_data(request.session, state) + return self.fetch_access_token(**params) + + +class DjangoOAuth2App(DjangoAppMixin, OAuth2Mixin, OpenIDMixin, BaseApp): + client_cls = OAuth2Session + + def logout_redirect( + self, request, post_logout_redirect_uri=None, id_token_hint=None, **kwargs + ): + """Create a HTTP Redirect for End Session Endpoint (RP-Initiated Logout). + + :param request: HTTP request instance from Django view. + :param post_logout_redirect_uri: URI to redirect after logout. + :param id_token_hint: ID Token previously issued to the RP. + :param kwargs: Extra parameters (state, client_id, logout_hint, ui_locales). + :return: A HTTP redirect response. + """ + result = self.create_logout_url( + post_logout_redirect_uri=post_logout_redirect_uri, + id_token_hint=id_token_hint, + **kwargs, + ) + if result.get("state"): + self.framework.set_state_data( + request.session, + result["state"], + { + "post_logout_redirect_uri": post_logout_redirect_uri, + }, + ) + return HttpResponseRedirect(result["url"]) + + def validate_logout_response(self, request): + """Validate the state parameter from the logout callback. + + :param request: HTTP request instance from Django view. + :return: The state data dict. + :raises OAuthError: If state is missing or invalid. + """ + state = request.GET.get("state") + if not state: + raise OAuthError(description='Missing "state" parameter') + + state_data = self.framework.get_state_data(request.session, state) + if not state_data: + raise OAuthError(description='Invalid "state" parameter') + + self.framework.clear_state_data(request.session, state) + return state_data + + def authorize_access_token(self, request, **kwargs): + """Fetch access token in one step. + + :param request: HTTP request instance from Django view. + :return: A token dict. + """ + if request.method == "GET": + error = request.GET.get("error") + if error: + description = request.GET.get("error_description") + raise OAuthError(error=error, description=description) + params = { + "code": request.GET.get("code"), + "state": request.GET.get("state"), + } + else: + params = { + "code": request.POST.get("code"), + "state": request.POST.get("state"), + } + + state_data = self.framework.get_state_data(request.session, params.get("state")) + self.framework.clear_state_data(request.session, params.get("state")) + params = self._format_state_params(state_data, params) + + claims_options = kwargs.pop("claims_options", None) + claims_cls = kwargs.pop("claims_cls", None) + leeway = kwargs.pop("leeway", 120) + token = self.fetch_access_token(**params, **kwargs) + + if "id_token" in token and "nonce" in state_data: + userinfo = self.parse_id_token( + token, + nonce=state_data["nonce"], + claims_options=claims_options, + claims_cls=claims_cls, + leeway=leeway, + ) + token["userinfo"] = userinfo + return token diff --git a/authlib/integrations/django_client/integration.py b/authlib/integrations/django_client/integration.py index 79d7dbde2..5f7f11dac 100644 --- a/authlib/integrations/django_client/integration.py +++ b/authlib/integrations/django_client/integration.py @@ -1,17 +1,12 @@ from django.conf import settings from django.dispatch import Signal -from django.http import HttpResponseRedirect -from ..base_client import FrameworkIntegration, RemoteApp -from ..requests_client import OAuth1Session, OAuth2Session +from ..base_client import FrameworkIntegration token_update = Signal() class DjangoIntegration(FrameworkIntegration): - oauth1_client_cls = OAuth1Session - oauth2_client_cls = OAuth2Session - def update_token(self, token, refresh_token=None, access_token=None): token_update.send( sender=self.__class__, @@ -21,51 +16,8 @@ def update_token(self, token, refresh_token=None, access_token=None): access_token=access_token, ) - def generate_access_token_params(self, request_token_url, request): - if request_token_url: - return request.GET.dict() - - if request.method == 'GET': - params = { - 'code': request.GET.get('code'), - 'state': request.GET.get('state'), - } - else: - params = { - 'code': request.POST.get('code'), - 'state': request.POST.get('state'), - } - return params - @staticmethod def load_config(oauth, name, params): - config = getattr(settings, 'AUTHLIB_OAUTH_CLIENTS', None) + config = getattr(settings, "AUTHLIB_OAUTH_CLIENTS", None) if config: return config.get(name) - - -class DjangoRemoteApp(RemoteApp): - def authorize_redirect(self, request, redirect_uri=None, **kwargs): - """Create a HTTP Redirect for Authorization Endpoint. - - :param request: HTTP request instance from Django view. - :param redirect_uri: Callback or redirect URI for authorization. - :param kwargs: Extra parameters to include. - :return: A HTTP redirect response. - """ - rv = self.create_authorization_url(redirect_uri, **kwargs) - self.save_authorize_data(request, redirect_uri=redirect_uri, **rv) - return HttpResponseRedirect(rv['url']) - - def authorize_access_token(self, request, **kwargs): - """Fetch access token in one step. - - :param request: HTTP request instance from Django view. - :return: A token dict. - """ - params = self.retrieve_access_token_params(request) - params.update(kwargs) - return self.fetch_access_token(**params) - - def parse_id_token(self, request, token, claims_options=None, leeway=120): - return self._parse_id_token(request, token, claims_options, leeway) diff --git a/authlib/integrations/django_helpers.py b/authlib/integrations/django_helpers.py deleted file mode 100644 index 117958d2b..000000000 --- a/authlib/integrations/django_helpers.py +++ /dev/null @@ -1,72 +0,0 @@ -try: - from collections.abc import MutableMapping as DictMixin -except ImportError: - from collections import MutableMapping as DictMixin -from authlib.common.encoding import to_unicode, json_loads - - -def create_oauth_request(request, request_cls, use_json=False): - if isinstance(request, request_cls): - return request - - if request.method == 'POST': - if use_json: - body = json_loads(request.body) - else: - body = request.POST.dict() - else: - body = None - - headers = parse_request_headers(request) - url = request.get_raw_uri() - return request_cls(request.method, url, body, headers) - - -def parse_request_headers(request): - return WSGIHeaderDict(request.META) - - -class WSGIHeaderDict(DictMixin): - CGI_KEYS = ('CONTENT_TYPE', 'CONTENT_LENGTH') - - def __init__(self, environ): - self.environ = environ - - def keys(self): - return [x for x in self] - - def _ekey(self, key): - key = key.replace('-', '_').upper() - if key in self.CGI_KEYS: - return key - return 'HTTP_' + key - - def __getitem__(self, key): - return _unicode_value(self.environ[self._ekey(key)]) - - def __delitem__(self, key): # pragma: no cover - raise ValueError('Can not delete item') - - def __setitem__(self, key, value): # pragma: no cover - raise ValueError('Can not set item') - - def __iter__(self): - for key in self.environ: - if key[:5] == 'HTTP_': - yield _unify_key(key[5:]) - elif key in self.CGI_KEYS: - yield _unify_key(key) - - def __len__(self): - return len(self.keys()) - - def __contains__(self, key): - return self._ekey(key) in self.environ - - -def _unicode_value(value): - return to_unicode(value, 'latin-1') - - -def _unify_key(key): - return key.replace('_', '-').title() diff --git a/authlib/integrations/django_oauth1/__init__.py b/authlib/integrations/django_oauth1/__init__.py index 39f0e1307..7a479c80c 100644 --- a/authlib/integrations/django_oauth1/__init__.py +++ b/authlib/integrations/django_oauth1/__init__.py @@ -1,9 +1,5 @@ -# flake8: noqa - -from .authorization_server import ( - BaseServer, CacheAuthorizationServer -) +from .authorization_server import BaseServer +from .authorization_server import CacheAuthorizationServer from .resource_protector import ResourceProtector - -__all__ = ['BaseServer', 'CacheAuthorizationServer', 'ResourceProtector'] +__all__ = ["BaseServer", "CacheAuthorizationServer", "ResourceProtector"] diff --git a/authlib/integrations/django_oauth1/authorization_server.py b/authlib/integrations/django_oauth1/authorization_server.py index 0ac8b5c1f..90195b18a 100644 --- a/authlib/integrations/django_oauth1/authorization_server.py +++ b/authlib/integrations/django_oauth1/authorization_server.py @@ -1,16 +1,16 @@ import logging -from authlib.oauth1 import ( - OAuth1Request, - AuthorizationServer as _AuthorizationServer, -) -from authlib.oauth1 import TemporaryCredential -from authlib.common.security import generate_token -from authlib.common.urls import url_encode -from django.core.cache import cache + from django.conf import settings +from django.core.cache import cache from django.http import HttpResponse + +from authlib.common.security import generate_token +from authlib.common.urls import url_encode +from authlib.oauth1 import AuthorizationServer as _AuthorizationServer +from authlib.oauth1 import OAuth1Request +from authlib.oauth1 import TemporaryCredential + from .nonce import exists_nonce_in_cache -from ..django_helpers import create_oauth_request log = logging.getLogger(__name__) @@ -21,16 +21,17 @@ def __init__(self, client_model, token_model, token_generator=None): self.token_model = token_model if token_generator is None: + def token_generator(): return { - 'oauth_token': generate_token(42), - 'oauth_token_secret': generate_token(48) + "oauth_token": generate_token(42), + "oauth_token_secret": generate_token(48), } self.token_generator = token_generator - self._config = getattr(settings, 'AUTHLIB_OAUTH1_PROVIDER', {}) - self._nonce_expires_in = self._config.get('nonce_expires_in', 86400) - methods = self._config.get('signature_methods') + self._config = getattr(settings, "AUTHLIB_OAUTH1_PROVIDER", {}) + self._nonce_expires_in = self._config.get("nonce_expires_in", 86400) + methods = self._config.get("signature_methods") if methods: self.SUPPORTED_SIGNATURE_METHODS = methods @@ -47,10 +48,10 @@ def create_token_credential(self, request): temporary_credential = request.credential token = self.token_generator() item = self.token_model( - oauth_token=token['oauth_token'], - oauth_token_secret=token['oauth_token_secret'], + oauth_token=token["oauth_token"], + oauth_token_secret=token["oauth_token_secret"], user_id=temporary_credential.get_user_id(), - client_id=temporary_credential.get_client_id() + client_id=temporary_credential.get_client_id(), ) item.save() return item @@ -61,7 +62,12 @@ def check_authorization_request(self, request): return req def create_oauth1_request(self, request): - return create_oauth_request(request, OAuth1Request) + if request.method == "POST": + body = request.POST.dict() + else: + body = None + url = request.build_absolute_uri() + return OAuth1Request(request.method, url, body, request.headers) def handle_response(self, status_code, payload, headers): resp = HttpResponse(url_encode(payload), status=status_code) @@ -72,12 +78,13 @@ def handle_response(self, status_code, payload, headers): class CacheAuthorizationServer(BaseServer): def __init__(self, client_model, token_model, token_generator=None): - super(CacheAuthorizationServer, self).__init__( - client_model, token_model, token_generator) + super().__init__(client_model, token_model, token_generator) self._temporary_expires_in = self._config.get( - 'temporary_credential_expires_in', 86400) + "temporary_credential_expires_in", 86400 + ) self._temporary_credential_key_prefix = self._config.get( - 'temporary_credential_key_prefix', 'temporary_credential:') + "temporary_credential_key_prefix", "temporary_credential:" + ) def create_temporary_credential(self, request): key_prefix = self._temporary_credential_key_prefix @@ -85,10 +92,10 @@ def create_temporary_credential(self, request): client_id = request.client_id redirect_uri = request.redirect_uri - key = key_prefix + token['oauth_token'] - token['client_id'] = client_id + key = key_prefix + token["oauth_token"] + token["client_id"] = client_id if redirect_uri: - token['oauth_callback'] = redirect_uri + token["oauth_callback"] = redirect_uri cache.set(key, token, timeout=self._temporary_expires_in) return TemporaryCredential(token) @@ -115,7 +122,7 @@ def create_authorization_verifier(self, request): credential = request.credential user = request.user key = key_prefix + credential.get_oauth_token() - credential['oauth_verifier'] = verifier - credential['user_id'] = user.pk + credential["oauth_verifier"] = verifier + credential["user_id"] = user.pk cache.set(key, credential, timeout=self._temporary_expires_in) return verifier diff --git a/authlib/integrations/django_oauth1/nonce.py b/authlib/integrations/django_oauth1/nonce.py index 535bf7e6b..a4b21c5f6 100644 --- a/authlib/integrations/django_oauth1/nonce.py +++ b/authlib/integrations/django_oauth1/nonce.py @@ -2,13 +2,13 @@ def exists_nonce_in_cache(nonce, request, timeout): - key_prefix = 'nonce:' + key_prefix = "nonce:" timestamp = request.timestamp client_id = request.client_id token = request.token - key = '{}{}-{}-{}'.format(key_prefix, nonce, timestamp, client_id) + key = f"{key_prefix}{nonce}-{timestamp}-{client_id}" if token: - key = '{}-{}'.format(key, token) + key = f"{key}-{token}" rv = bool(cache.get(key)) cache.set(key, 1, timeout=timeout) diff --git a/authlib/integrations/django_oauth1/resource_protector.py b/authlib/integrations/django_oauth1/resource_protector.py index cc2854b64..89897717e 100644 --- a/authlib/integrations/django_oauth1/resource_protector.py +++ b/authlib/integrations/django_oauth1/resource_protector.py @@ -1,10 +1,12 @@ import functools -from authlib.oauth1.errors import OAuth1Error -from authlib.oauth1 import ResourceProtector as _ResourceProtector -from django.http import JsonResponse + from django.conf import settings +from django.http import JsonResponse + +from authlib.oauth1 import ResourceProtector as _ResourceProtector +from authlib.oauth1.errors import OAuth1Error + from .nonce import exists_nonce_in_cache -from ..django_helpers import parse_request_headers class ResourceProtector(_ResourceProtector): @@ -12,12 +14,12 @@ def __init__(self, client_model, token_model): self.client_model = client_model self.token_model = token_model - config = getattr(settings, 'AUTHLIB_OAUTH1_PROVIDER', {}) - methods = config.get('signature_methods', []) + config = getattr(settings, "AUTHLIB_OAUTH1_PROVIDER", {}) + methods = config.get("signature_methods", []) if methods and isinstance(methods, (list, tuple)): self.SUPPORTED_SIGNATURE_METHODS = methods - self._nonce_expires_in = config.get('nonce_expires_in', 86400) + self._nonce_expires_in = config.get("nonce_expires_in", 86400) def get_client_by_id(self, client_id): try: @@ -28,8 +30,7 @@ def get_client_by_id(self, client_id): def get_token_credential(self, request): try: return self.token_model.objects.get( - client_id=request.client_id, - oauth_token=request.token + client_id=request.client_id, oauth_token=request.token ) except self.token_model.DoesNotExist: return None @@ -38,18 +39,17 @@ def exists_nonce(self, nonce, request): return exists_nonce_in_cache(nonce, request, self._nonce_expires_in) def acquire_credential(self, request): - if request.method in ['POST', 'PUT']: + if request.method in ["POST", "PUT"]: body = request.POST.dict() else: body = None - headers = parse_request_headers(request) - url = request.get_raw_uri() - req = self.validate_request(request.method, url, body, headers) + url = request.build_absolute_uri() + req = self.validate_request(request.method, url, body, request.headers) return req.credential def __call__(self, realm=None): - def wrapper(f): + def decorator(f): @functools.wraps(f) def decorated(request, *args, **kwargs): try: @@ -58,9 +58,13 @@ def decorated(request, *args, **kwargs): except OAuth1Error as error: body = dict(error.get_body()) resp = JsonResponse(body, status=error.status_code) - resp['Cache-Control'] = 'no-store' - resp['Pragma'] = 'no-cache' + resp["Cache-Control"] = "no-store" + resp["Pragma"] = "no-cache" return resp return f(request, *args, **kwargs) + return decorated - return wrapper + + if callable(realm): + return decorator(realm) + return decorator diff --git a/authlib/integrations/django_oauth2/__init__.py b/authlib/integrations/django_oauth2/__init__.py index 05c1fdfe9..79b4773a5 100644 --- a/authlib/integrations/django_oauth2/__init__.py +++ b/authlib/integrations/django_oauth2/__init__.py @@ -1,10 +1,9 @@ # flake8: noqa from .authorization_server import AuthorizationServer -from .resource_protector import ResourceProtector, BearerTokenValidator from .endpoints import RevocationEndpoint -from .signals import ( - client_authenticated, - token_authenticated, - token_revoked -) +from .resource_protector import BearerTokenValidator +from .resource_protector import ResourceProtector +from .signals import client_authenticated +from .signals import token_authenticated +from .signals import token_revoked diff --git a/authlib/integrations/django_oauth2/authorization_server.py b/authlib/integrations/django_oauth2/authorization_server.py index fae60aa42..cdae210fe 100644 --- a/authlib/integrations/django_oauth2/authorization_server.py +++ b/authlib/integrations/django_oauth2/authorization_server.py @@ -1,18 +1,16 @@ -import json +from django.conf import settings from django.http import HttpResponse from django.utils.module_loading import import_string -from django.conf import settings -from authlib.oauth2 import ( - OAuth2Request, - HttpRequest, - AuthorizationServer as _AuthorizationServer, -) -from authlib.oauth2.rfc6750 import BearerToken -from authlib.oauth2.rfc8414 import AuthorizationServerMetadata -from authlib.common.security import generate_token as _generate_token + from authlib.common.encoding import json_dumps -from .signals import client_authenticated, token_revoked -from ..django_helpers import create_oauth_request +from authlib.common.security import generate_token as _generate_token +from authlib.oauth2 import AuthorizationServer as _AuthorizationServer +from authlib.oauth2.rfc6750 import BearerTokenGenerator + +from .requests import DjangoJsonRequest +from .requests import DjangoOAuth2Request +from .signals import client_authenticated +from .signals import token_revoked class AuthorizationServer(_AuthorizationServer): @@ -24,33 +22,21 @@ class AuthorizationServer(_AuthorizationServer): server = AuthorizationServer(OAuth2Client, OAuth2Token) """ - metadata_class = AuthorizationServerMetadata - def __init__(self, client_model, token_model, generate_token=None, metadata=None): - self.config = getattr(settings, 'AUTHLIB_OAUTH2_PROVIDER', {}) + def __init__(self, client_model, token_model): + super().__init__() self.client_model = client_model self.token_model = token_model - if generate_token is None: - generate_token = self.create_bearer_token_generator() - - if metadata is None: - metadata_file = self.config.get('metadata_file') - if metadata_file: - with open(metadata_file) as f: - metadata = json.load(f) - - if metadata: - metadata = self.metadata_class(metadata) - metadata.validate() - - super(AuthorizationServer, self).__init__( - query_client=self.get_client_by_id, - save_token=self.save_oauth2_token, - generate_token=generate_token, - metadata=metadata, - ) + self.load_config(getattr(settings, "AUTHLIB_OAUTH2_PROVIDER", {})) - def get_client_by_id(self, client_id): + def load_config(self, config): + self.config = config + scopes_supported = self.config.get("scopes_supported") + self.scopes_supported = scopes_supported + # add default token generator + self.register_token_generator("default", self.create_bearer_token_generator()) + + def query_client(self, client_id): """Default method for ``AuthorizationServer.query_client``. Developers MAY rewrite this function to meet their own needs. """ @@ -59,7 +45,7 @@ def get_client_by_id(self, client_id): except self.client_model.DoesNotExist: return None - def save_oauth2_token(self, token, request): + def save_token(self, token, request): """Default method for ``AuthorizationServer.save_token``. Developers MAY rewrite this function to meet their own needs. """ @@ -68,21 +54,15 @@ def save_oauth2_token(self, token, request): user_id = request.user.pk else: user_id = client.user_id - item = self.token_model( - client_id=client.client_id, - user_id=user_id, - **token - ) + item = self.token_model(client_id=client.client_id, user_id=user_id, **token) item.save() return item def create_oauth2_request(self, request): - return create_oauth_request(request, OAuth2Request) + return DjangoOAuth2Request(request) def create_json_request(self, request): - req = create_oauth_request(request, HttpRequest, True) - req.user = request.user - return req + return DjangoJsonRequest(request) def handle_response(self, status_code, payload, headers): if isinstance(payload, dict): @@ -93,40 +73,28 @@ def handle_response(self, status_code, payload, headers): return resp def send_signal(self, name, *args, **kwargs): - if name == 'after_authenticate_client': - client_authenticated.send(sender=self.__class__, *args, **kwargs) - elif name == 'after_revoke_token': - token_revoked.send(sender=self.__class__, *args, **kwargs) + if name == "after_authenticate_client": + client_authenticated.send(*args, sender=self.__class__, **kwargs) + elif name == "after_revoke_token": + token_revoked.send(*args, sender=self.__class__, **kwargs) def create_bearer_token_generator(self): """Default method to create BearerToken generator.""" - conf = self.config.get('access_token_generator', True) + conf = self.config.get("access_token_generator", True) access_token_generator = create_token_generator(conf, 42) - conf = self.config.get('refresh_token_generator', False) + conf = self.config.get("refresh_token_generator", False) refresh_token_generator = create_token_generator(conf, 48) - conf = self.config.get('token_expires_in') + conf = self.config.get("token_expires_in") expires_generator = create_token_expires_in_generator(conf) - return BearerToken( + return BearerTokenGenerator( access_token_generator=access_token_generator, refresh_token_generator=refresh_token_generator, expires_generator=expires_generator, ) - def get_consent_grant(self, request): - grant = self.get_authorization_grant(request) - grant.validate_consent_request() - if not hasattr(grant, 'prompt'): - grant.prompt = None - return grant - - def validate_consent_request(self, request, end_user=None): - req = self.create_oauth2_request(request) - req.user = end_user - return self.get_consent_grant(req) - def create_token_generator(token_generator_conf, length=42): if callable(token_generator_conf): @@ -135,18 +103,20 @@ def create_token_generator(token_generator_conf, length=42): if isinstance(token_generator_conf, str): return import_string(token_generator_conf) elif token_generator_conf is True: + def token_generator(*args, **kwargs): return _generate_token(length) + return token_generator def create_token_expires_in_generator(expires_in_conf=None): data = {} - data.update(BearerToken.GRANT_TYPES_EXPIRES_IN) + data.update(BearerTokenGenerator.GRANT_TYPES_EXPIRES_IN) if expires_in_conf: data.update(expires_in_conf) def expires_in(client, grant_type): - return data.get(grant_type, BearerToken.DEFAULT_EXPIRES_IN) + return data.get(grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN) return expires_in diff --git a/authlib/integrations/django_oauth2/endpoints.py b/authlib/integrations/django_oauth2/endpoints.py index b3a8ccd37..08a9d4f67 100644 --- a/authlib/integrations/django_oauth2/endpoints.py +++ b/authlib/integrations/django_oauth2/endpoints.py @@ -14,32 +14,29 @@ class RevocationEndpoint(_RevocationEndpoint): # see register into authorization server instance server.register_endpoint(RevocationEndpoint) + @require_http_methods(["POST"]) def revoke_token(request): return server.create_endpoint_response( - RevocationEndpoint.ENDPOINT_NAME, - request + RevocationEndpoint.ENDPOINT_NAME, request ) """ - def query_token(self, token, token_type_hint, client): + def query_token(self, token, token_type_hint): """Query requested token from database.""" token_model = self.server.token_model - if token_type_hint == 'access_token': + if token_type_hint == "access_token": rv = _query_access_token(token_model, token) - elif token_type_hint == 'refresh_token': + elif token_type_hint == "refresh_token": rv = _query_refresh_token(token_model, token) else: rv = _query_access_token(token_model, token) if not rv: rv = _query_refresh_token(token_model, token) - client_id = client.get_client_id() - if rv and rv.client_id == client_id: - return rv - return None + return rv - def revoke_token(self, token): + def revoke_token(self, token, request): """Mark the give token as revoked.""" token.revoked = True token.save() diff --git a/authlib/integrations/django_oauth2/requests.py b/authlib/integrations/django_oauth2/requests.py new file mode 100644 index 000000000..b490cb705 --- /dev/null +++ b/authlib/integrations/django_oauth2/requests.py @@ -0,0 +1,65 @@ +from collections import defaultdict + +from django.http import HttpRequest +from django.utils.functional import cached_property + +from authlib.common.encoding import json_loads +from authlib.oauth2.rfc6749 import JsonPayload +from authlib.oauth2.rfc6749 import JsonRequest +from authlib.oauth2.rfc6749 import OAuth2Payload +from authlib.oauth2.rfc6749 import OAuth2Request + + +class DjangoOAuth2Payload(OAuth2Payload): + def __init__(self, request: HttpRequest): + self._request = request + + @cached_property + def data(self): + data = {} + data.update(self._request.GET.dict()) + data.update(self._request.POST.dict()) + return data + + @cached_property + def datalist(self): + values = defaultdict(list) + for k in self._request.GET: + values[k].extend(self._request.GET.getlist(k)) + for k in self._request.POST: + values[k].extend(self._request.POST.getlist(k)) + return values + + +class DjangoOAuth2Request(OAuth2Request): + def __init__(self, request: HttpRequest): + super().__init__( + method=request.method, + uri=request.build_absolute_uri(), + headers=request.headers, + ) + self.payload = DjangoOAuth2Payload(request) + self._request = request + + @property + def args(self): + return self._request.GET + + @property + def form(self): + return self._request.POST + + +class DjangoJsonPayload(JsonPayload): + def __init__(self, request: HttpRequest): + self._request = request + + @cached_property + def data(self): + return json_loads(self._request.body) + + +class DjangoJsonRequest(JsonRequest): + def __init__(self, request: HttpRequest): + super().__init__(request.method, request.build_absolute_uri(), request.headers) + self.payload = DjangoJsonPayload(request) diff --git a/authlib/integrations/django_oauth2/resource_protector.py b/authlib/integrations/django_oauth2/resource_protector.py index 3e7f78dee..697e9b972 100644 --- a/authlib/integrations/django_oauth2/resource_protector.py +++ b/authlib/integrations/django_oauth2/resource_protector.py @@ -1,44 +1,43 @@ import functools + from django.http import JsonResponse -from authlib.oauth2 import ( - OAuth2Error, - ResourceProtector as _ResourceProtector, -) -from authlib.oauth2.rfc6749 import ( - MissingAuthorizationError, - HttpRequest, -) -from authlib.oauth2.rfc6750 import ( - BearerTokenValidator as _BearerTokenValidator -) + +from authlib.oauth2 import OAuth2Error +from authlib.oauth2 import ResourceProtector as _ResourceProtector +from authlib.oauth2.rfc6749 import MissingAuthorizationError +from authlib.oauth2.rfc6750 import BearerTokenValidator as _BearerTokenValidator + +from .requests import DjangoJsonRequest from .signals import token_authenticated -from ..django_helpers import parse_request_headers class ResourceProtector(_ResourceProtector): - def acquire_token(self, request, scope=None, operator='AND'): + def acquire_token(self, request, scopes=None, **kwargs): """A method to acquire current valid token with the given scope. :param request: Django HTTP request instance - :param scope: string or list of scope values - :param operator: value of "AND" or "OR" + :param scopes: a list of scope values :return: token object """ - headers = parse_request_headers(request) - url = request.get_raw_uri() - req = HttpRequest(request.method, url, request.body, headers) - if not callable(operator): - operator = operator.upper() - token = self.validate_request(scope, req, operator) + req = DjangoJsonRequest(request) + # backward compatibility + kwargs["scopes"] = scopes + for claim in kwargs: + if isinstance(kwargs[claim], str): + kwargs[claim] = [kwargs[claim]] + token = self.validate_request(request=req, **kwargs) token_authenticated.send(sender=self.__class__, token=token) return token - def __call__(self, scope=None, operator='AND', optional=False): - def wrapper(f): + def __call__(self, scopes=None, optional=False, **kwargs): + claims = kwargs + claims["scopes"] = scopes if not callable(scopes) else None + + def decorator(f): @functools.wraps(f) def decorated(request, *args, **kwargs): try: - token = self.acquire_token(request, scope, operator) + token = self.acquire_token(request, **claims) request.oauth_token = token except MissingAuthorizationError as error: if optional: @@ -48,14 +47,18 @@ def decorated(request, *args, **kwargs): except OAuth2Error as error: return return_error_response(error) return f(request, *args, **kwargs) + return decorated - return wrapper + + if callable(scopes): + return decorator(scopes) + return decorator class BearerTokenValidator(_BearerTokenValidator): - def __init__(self, token_model, realm=None): + def __init__(self, token_model, realm=None, **extra_attributes): self.token_model = token_model - super(BearerTokenValidator, self).__init__(realm) + super().__init__(realm, **extra_attributes) def authenticate_token(self, token_string): try: @@ -63,12 +66,6 @@ def authenticate_token(self, token_string): except self.token_model.DoesNotExist: return None - def request_invalid(self, request): - return False - - def token_revoked(self, token): - return token.revoked - def return_error_response(error): body = dict(error.get_body()) diff --git a/authlib/integrations/django_oauth2/signals.py b/authlib/integrations/django_oauth2/signals.py index 0e9c2659b..5d22216fe 100644 --- a/authlib/integrations/django_oauth2/signals.py +++ b/authlib/integrations/django_oauth2/signals.py @@ -1,6 +1,5 @@ from django.dispatch import Signal - #: signal when client is authenticated client_authenticated = Signal() diff --git a/authlib/integrations/flask_client/__init__.py b/authlib/integrations/flask_client/__init__.py index 9aa6f7138..d6404acf8 100644 --- a/authlib/integrations/flask_client/__init__.py +++ b/authlib/integrations/flask_client/__init__.py @@ -1,11 +1,59 @@ -# flake8: noqa +from werkzeug.local import LocalProxy -from .oauth_registry import OAuth -from .remote_app import FlaskRemoteApp -from .integration import token_update, FlaskIntegration +from ..base_client import BaseOAuth from ..base_client import OAuthError +from .apps import FlaskOAuth1App +from .apps import FlaskOAuth2App +from .integration import FlaskIntegration +from .integration import token_update + + +class OAuth(BaseOAuth): + oauth1_client_cls = FlaskOAuth1App + oauth2_client_cls = FlaskOAuth2App + framework_integration_cls = FlaskIntegration + + def __init__(self, app=None, cache=None, fetch_token=None, update_token=None): + super().__init__( + cache=cache, fetch_token=fetch_token, update_token=update_token + ) + self.app = app + if app: + self.init_app(app) + + def init_app(self, app, cache=None, fetch_token=None, update_token=None): + """Initialize lazy for Flask app. This is usually used for Flask application + factory pattern. + """ + self.app = app + if cache is not None: + self.cache = cache + + if fetch_token: + self.fetch_token = fetch_token + if update_token: + self.update_token = update_token + + app.extensions = getattr(app, "extensions", {}) + app.extensions["authlib.integrations.flask_client"] = self + + def create_client(self, name): + if not self.app: + raise RuntimeError("OAuth is not init with Flask app.") + return super().create_client(name) + + def register(self, name, overwrite=False, **kwargs): + self._registry[name] = (overwrite, kwargs) + if self.app: + return self.create_client(name) + return LocalProxy(lambda: self.create_client(name)) + __all__ = [ - 'OAuth', 'FlaskRemoteApp', 'FlaskIntegration', - 'token_update', 'OAuthError', + "OAuth", + "FlaskIntegration", + "FlaskOAuth1App", + "FlaskOAuth2App", + "token_update", + "OAuthError", ] diff --git a/authlib/integrations/flask_client/apps.py b/authlib/integrations/flask_client/apps.py new file mode 100644 index 000000000..81cae167c --- /dev/null +++ b/authlib/integrations/flask_client/apps.py @@ -0,0 +1,164 @@ +from flask import g +from flask import redirect +from flask import request +from flask import session + +from ..base_client import BaseApp +from ..base_client import OAuth1Mixin +from ..base_client import OAuth2Mixin +from ..base_client import OAuthError +from ..base_client import OpenIDMixin +from ..requests_client import OAuth1Session +from ..requests_client import OAuth2Session + + +class FlaskAppMixin: + @property + def token(self): + attr = f"_oauth_token_{self.name}" + token = g.get(attr) + if token: + return token + if self._fetch_token: + token = self._fetch_token() + self.token = token + return token + + @token.setter + def token(self, token): + attr = f"_oauth_token_{self.name}" + setattr(g, attr, token) + + def _get_requested_token(self, *args, **kwargs): + return self.token + + def save_authorize_data(self, **kwargs): + state = kwargs.pop("state", None) + if state: + self.framework.set_state_data(session, state, kwargs) + else: + raise RuntimeError("Missing state value") + + def authorize_redirect(self, redirect_uri=None, **kwargs): + """Create a HTTP Redirect for Authorization Endpoint. + + :param redirect_uri: Callback or redirect URI for authorization. + :param kwargs: Extra parameters to include. + :return: A HTTP redirect response. + """ + rv = self.create_authorization_url(redirect_uri, **kwargs) + self.save_authorize_data(redirect_uri=redirect_uri, **rv) + return redirect(rv["url"]) + + +class FlaskOAuth1App(FlaskAppMixin, OAuth1Mixin, BaseApp): + client_cls = OAuth1Session + + def authorize_access_token(self, **kwargs): + """Fetch access token in one step. + + :return: A token dict. + """ + params = request.args.to_dict(flat=True) + state = params.get("oauth_token") + if not state: + raise OAuthError(description='Missing "oauth_token" parameter') + + data = self.framework.get_state_data(session, state) + if not data: + raise OAuthError(description='Missing "request_token" in temporary data') + + params["request_token"] = data["request_token"] + params.update(kwargs) + self.framework.clear_state_data(session, state) + token = self.fetch_access_token(**params) + self.token = token + return token + + +class FlaskOAuth2App(FlaskAppMixin, OAuth2Mixin, OpenIDMixin, BaseApp): + client_cls = OAuth2Session + + def logout_redirect( + self, post_logout_redirect_uri=None, id_token_hint=None, **kwargs + ): + """Create a HTTP Redirect for End Session Endpoint (RP-Initiated Logout). + + :param post_logout_redirect_uri: URI to redirect after logout. + :param id_token_hint: ID Token previously issued to the RP. + :param kwargs: Extra parameters (state, client_id, logout_hint, ui_locales). + :return: A HTTP redirect response. + """ + result = self.create_logout_url( + post_logout_redirect_uri=post_logout_redirect_uri, + id_token_hint=id_token_hint, + **kwargs, + ) + if result.get("state"): + self.framework.set_state_data( + session, + result["state"], + { + "post_logout_redirect_uri": post_logout_redirect_uri, + }, + ) + return redirect(result["url"]) + + def validate_logout_response(self): + """Validate the state parameter from the logout callback. + + :return: The state data dict. + :raises OAuthError: If state is missing or invalid. + """ + state = request.args.get("state") + if not state: + raise OAuthError(description='Missing "state" parameter') + + state_data = self.framework.get_state_data(session, state) + if not state_data: + raise OAuthError(description='Invalid "state" parameter') + + self.framework.clear_state_data(session, state) + return state_data + + def authorize_access_token(self, **kwargs): + """Fetch access token in one step. + + :return: A token dict. + """ + if request.method == "GET": + error = request.args.get("error") + if error: + description = request.args.get("error_description") + raise OAuthError(error=error, description=description) + + params = { + "code": request.args.get("code"), + "state": request.args.get("state"), + } + else: + params = { + "code": request.form.get("code"), + "state": request.form.get("state"), + } + + state_data = self.framework.get_state_data(session, params.get("state")) + self.framework.clear_state_data(session, params.get("state")) + params = self._format_state_params(state_data, params) + + claims_options = kwargs.pop("claims_options", None) + claims_cls = kwargs.pop("claims_cls", None) + leeway = kwargs.pop("leeway", 120) + token = self.fetch_access_token(**params, **kwargs) + self.token = token + + if "id_token" in token and "nonce" in state_data: + userinfo = self.parse_id_token( + token, + nonce=state_data["nonce"], + claims_options=claims_options, + claims_cls=claims_cls, + leeway=leeway, + ) + token["userinfo"] = userinfo + return token diff --git a/authlib/integrations/flask_client/integration.py b/authlib/integrations/flask_client/integration.py index 55a3f861b..e5fe3cbb0 100644 --- a/authlib/integrations/flask_client/integration.py +++ b/authlib/integrations/flask_client/integration.py @@ -1,55 +1,28 @@ -from flask import current_app, session +from flask import current_app from flask.signals import Namespace + from ..base_client import FrameworkIntegration -from ..requests_client import OAuth1Session, OAuth2Session _signal = Namespace() #: signal when token is updated -token_update = _signal.signal('token_update') +token_update = _signal.signal("token_update") class FlaskIntegration(FrameworkIntegration): - oauth1_client_cls = OAuth1Session - oauth2_client_cls = OAuth2Session - - def set_session_data(self, request, key, value): - sess_key = '_{}_authlib_{}_'.format(self.name, key) - session[sess_key] = value - - def get_session_data(self, request, key): - sess_key = '_{}_authlib_{}_'.format(self.name, key) - return session.pop(sess_key, None) - def update_token(self, token, refresh_token=None, access_token=None): token_update.send( - current_app, + current_app._get_current_object(), name=self.name, token=token, refresh_token=refresh_token, access_token=access_token, ) - def generate_access_token_params(self, request_token_url, request): - if request_token_url: - return request.args.to_dict(flat=True) - - if request.method == 'GET': - params = { - 'code': request.args['code'], - 'state': request.args.get('state'), - } - else: - params = { - 'code': request.form['code'], - 'state': request.form.get('state'), - } - return params - @staticmethod def load_config(oauth, name, params): rv = {} for k in params: - conf_key = '{}_{}'.format(name, k).upper() + conf_key = f"{name}_{k}".upper() v = oauth.app.config.get(conf_key, None) if v is not None: rv[k] = v diff --git a/authlib/integrations/flask_client/oauth_registry.py b/authlib/integrations/flask_client/oauth_registry.py deleted file mode 100644 index 8f5d1fe3e..000000000 --- a/authlib/integrations/flask_client/oauth_registry.py +++ /dev/null @@ -1,118 +0,0 @@ -import uuid -from flask import session -from werkzeug.local import LocalProxy -from .integration import FlaskIntegration -from .remote_app import FlaskRemoteApp -from ..base_client import BaseOAuth - -__all__ = ['OAuth'] -_req_token_tpl = '_{}_authlib_req_token_' - - -class OAuth(BaseOAuth): - """A Flask OAuth registry for oauth clients. - - Create an instance with Flask:: - - oauth = OAuth(app, cache=cache) - - You can also pass the instance of Flask later:: - - oauth = OAuth() - oauth.init_app(app, cache=cache) - - :param app: Flask application instance - :param cache: A cache instance that has .get .set and .delete methods - :param fetch_token: a shared function to get current user's token - :param update_token: a share function to update current user's token - """ - framework_client_cls = FlaskRemoteApp - framework_integration_cls = FlaskIntegration - - def __init__(self, app=None, cache=None, fetch_token=None, update_token=None): - super(OAuth, self).__init__(fetch_token, update_token) - - self.app = app - self.cache = cache - if app: - self.init_app(app) - - def init_app(self, app, cache=None, fetch_token=None, update_token=None): - """Initialize lazy for Flask app. This is usually used for Flask application - factory pattern. - """ - self.app = app - if cache is not None: - self.cache = cache - - if fetch_token: - self.fetch_token = fetch_token - if update_token: - self.update_token = update_token - - app.extensions = getattr(app, 'extensions', {}) - app.extensions['authlib.integrations.flask_client'] = self - - def create_client(self, name): - if not self.app: - raise RuntimeError('OAuth is not init with Flask app.') - return super(OAuth, self).create_client(name) - - def register(self, name, overwrite=False, **kwargs): - self._registry[name] = (overwrite, kwargs) - if self.app: - return self.create_client(name) - return LocalProxy(lambda: self.create_client(name)) - - def generate_client_kwargs(self, name, overwrite, **kwargs): - kwargs = super(OAuth, self).generate_client_kwargs(name, overwrite, **kwargs) - - if kwargs.get('request_token_url'): - if self.cache: - _add_cache_request_token(self.cache, name, kwargs) - else: - _add_session_request_token(name, kwargs) - return kwargs - - -def _add_cache_request_token(cache, name, kwargs): - if not kwargs.get('fetch_request_token'): - def fetch_request_token(): - key = _req_token_tpl.format(name) - sid = session.pop(key, None) - if not sid: - return None - - token = cache.get(sid) - cache.delete(sid) - return token - - kwargs['fetch_request_token'] = fetch_request_token - - if not kwargs.get('save_request_token'): - def save_request_token(token): - key = _req_token_tpl.format(name) - sid = uuid.uuid4().hex - session[key] = sid - cache.set(sid, token, 600) - - kwargs['save_request_token'] = save_request_token - return kwargs - - -def _add_session_request_token(name, kwargs): - if not kwargs.get('fetch_request_token'): - def fetch_request_token(): - key = _req_token_tpl.format(name) - return session.pop(key, None) - - kwargs['fetch_request_token'] = fetch_request_token - - if not kwargs.get('save_request_token'): - def save_request_token(token): - key = _req_token_tpl.format(name) - session[key] = token - - kwargs['save_request_token'] = save_request_token - - return kwargs diff --git a/authlib/integrations/flask_client/remote_app.py b/authlib/integrations/flask_client/remote_app.py deleted file mode 100644 index 0d8eecb73..000000000 --- a/authlib/integrations/flask_client/remote_app.py +++ /dev/null @@ -1,81 +0,0 @@ -from flask import redirect -from flask import request as flask_req -from flask import _app_ctx_stack -from ..base_client import RemoteApp - - -class FlaskRemoteApp(RemoteApp): - """Flask integrated RemoteApp of :class:`~authlib.client.OAuthClient`. - It has built-in hooks for OAuthClient. The only required configuration - is token model. - """ - - def __init__(self, framework, name=None, fetch_token=None, **kwargs): - fetch_request_token = kwargs.pop('fetch_request_token', None) - save_request_token = kwargs.pop('save_request_token', None) - super(FlaskRemoteApp, self).__init__(framework, name, fetch_token, **kwargs) - - self._fetch_request_token = fetch_request_token - self._save_request_token = save_request_token - - def _on_update_token(self, token, refresh_token=None, access_token=None): - self.token = token - super(FlaskRemoteApp, self)._on_update_token( - token, refresh_token, access_token - ) - - @property - def token(self): - ctx = _app_ctx_stack.top - attr = 'authlib_oauth_token_{}'.format(self.name) - token = getattr(ctx, attr, None) - if token: - return token - if self._fetch_token: - token = self._fetch_token() - self.token = token - return token - - @token.setter - def token(self, token): - ctx = _app_ctx_stack.top - attr = 'authlib_oauth_token_{}'.format(self.name) - setattr(ctx, attr, token) - - def request(self, method, url, token=None, **kwargs): - if token is None and not kwargs.get('withhold_token'): - token = self.token - return super(FlaskRemoteApp, self).request( - method, url, token=token, **kwargs) - - def authorize_redirect(self, redirect_uri=None, **kwargs): - """Create a HTTP Redirect for Authorization Endpoint. - - :param redirect_uri: Callback or redirect URI for authorization. - :param kwargs: Extra parameters to include. - :return: A HTTP redirect response. - """ - rv = self.create_authorization_url(redirect_uri, **kwargs) - - if self.request_token_url: - request_token = rv.pop('request_token', None) - self._save_request_token(request_token) - - self.save_authorize_data(flask_req, redirect_uri=redirect_uri, **rv) - return redirect(rv['url']) - - def authorize_access_token(self, **kwargs): - """Authorize access token.""" - if self.request_token_url: - request_token = self._fetch_request_token() - else: - request_token = None - - params = self.retrieve_access_token_params(flask_req, request_token) - params.update(kwargs) - token = self.fetch_access_token(**params) - self.token = token - return token - - def parse_id_token(self, token, claims_options=None, leeway=120): - return self._parse_id_token(flask_req, token, claims_options, leeway) diff --git a/authlib/integrations/flask_helpers.py b/authlib/integrations/flask_helpers.py deleted file mode 100644 index 6883e4b68..000000000 --- a/authlib/integrations/flask_helpers.py +++ /dev/null @@ -1,25 +0,0 @@ -from flask import request as flask_req -from authlib.common.encoding import to_unicode - - -def create_oauth_request(request, request_cls, use_json=False): - if isinstance(request, request_cls): - return request - - if not request: - request = flask_req - - if request.method == 'POST': - if use_json: - body = request.get_json() - else: - body = request.form.to_dict(flat=True) - else: - body = None - - # query string in werkzeug Request.url is very weird - # scope=profile%20email will be scope=profile email - url = request.base_url - if request.query_string: - url = url + '?' + to_unicode(request.query_string) - return request_cls(request.method, url, body, request.headers) diff --git a/authlib/integrations/flask_oauth1/__init__.py b/authlib/integrations/flask_oauth1/__init__.py index 780b05945..dd20d9201 100644 --- a/authlib/integrations/flask_oauth1/__init__.py +++ b/authlib/integrations/flask_oauth1/__init__.py @@ -1,9 +1,8 @@ # flake8: noqa from .authorization_server import AuthorizationServer -from .resource_protector import ResourceProtector, current_credential -from .cache import ( - register_nonce_hooks, - register_temporary_credential_hooks, - create_exists_nonce_func, -) +from .cache import create_exists_nonce_func +from .cache import register_nonce_hooks +from .cache import register_temporary_credential_hooks +from .resource_protector import ResourceProtector +from .resource_protector import current_credential diff --git a/authlib/integrations/flask_oauth1/authorization_server.py b/authlib/integrations/flask_oauth1/authorization_server.py index 1062a7b17..8cf6afe07 100644 --- a/authlib/integrations/flask_oauth1/authorization_server.py +++ b/authlib/integrations/flask_oauth1/authorization_server.py @@ -1,13 +1,13 @@ import logging -from werkzeug.utils import import_string + from flask import Response -from authlib.oauth1 import ( - OAuth1Request, - AuthorizationServer as _AuthorizationServer, -) +from flask import request as flask_req +from werkzeug.utils import import_string + from authlib.common.security import generate_token from authlib.common.urls import url_encode -from ..flask_helpers import create_oauth_request +from authlib.oauth1 import AuthorizationServer as _AuthorizationServer +from authlib.oauth1 import OAuth1Request log = logging.getLogger(__name__) @@ -34,12 +34,12 @@ def __init__(self, app=None, query_client=None, token_generator=None): self.token_generator = token_generator self._hooks = { - 'exists_nonce': None, - 'create_temporary_credential': None, - 'get_temporary_credential': None, - 'delete_temporary_credential': None, - 'create_authorization_verifier': None, - 'create_token_credential': None, + "exists_nonce": None, + "create_temporary_credential": None, + "get_temporary_credential": None, + "delete_temporary_credential": None, + "create_authorization_verifier": None, + "create_token_credential": None, } if app is not None: self.init_app(app) @@ -53,7 +53,7 @@ def init_app(self, app, query_client=None, token_generator=None): if self.token_generator is None: self.token_generator = self.create_token_generator(app) - methods = app.config.get('OAUTH1_SUPPORTED_SIGNATURE_METHODS') + methods = app.config.get("OAUTH1_SUPPORTED_SIGNATURE_METHODS") if methods and isinstance(methods, (list, tuple)): self.SUPPORTED_SIGNATURE_METHODS = methods @@ -65,37 +65,38 @@ def register_hook(self, name, func): self._hooks[name] = func def create_token_generator(self, app): - token_generator = app.config.get('OAUTH1_TOKEN_GENERATOR') + token_generator = app.config.get("OAUTH1_TOKEN_GENERATOR") if isinstance(token_generator, str): token_generator = import_string(token_generator) else: - length = app.config.get('OAUTH1_TOKEN_LENGTH', 42) + length = app.config.get("OAUTH1_TOKEN_LENGTH", 42) def token_generator(): return generate_token(length) - secret_generator = app.config.get('OAUTH1_TOKEN_SECRET_GENERATOR') + secret_generator = app.config.get("OAUTH1_TOKEN_SECRET_GENERATOR") if isinstance(secret_generator, str): secret_generator = import_string(secret_generator) else: - length = app.config.get('OAUTH1_TOKEN_SECRET_LENGTH', 48) + length = app.config.get("OAUTH1_TOKEN_SECRET_LENGTH", 48) def secret_generator(): return generate_token(length) def create_token(): return { - 'oauth_token': token_generator(), - 'oauth_token_secret': secret_generator() + "oauth_token": token_generator(), + "oauth_token_secret": secret_generator(), } + return create_token def get_client_by_id(self, client_id): return self.query_client(client_id) def exists_nonce(self, nonce, request): - func = self._hooks['exists_nonce'] + func = self._hooks["exists_nonce"] if callable(func): timestamp = request.timestamp client_id = request.client_id @@ -105,57 +106,43 @@ def exists_nonce(self, nonce, request): raise RuntimeError('"exists_nonce" hook is required.') def create_temporary_credential(self, request): - func = self._hooks['create_temporary_credential'] + func = self._hooks["create_temporary_credential"] if callable(func): token = self.token_generator() return func(token, request.client_id, request.redirect_uri) - raise RuntimeError( - '"create_temporary_credential" hook is required.' - ) + raise RuntimeError('"create_temporary_credential" hook is required.') def get_temporary_credential(self, request): - func = self._hooks['get_temporary_credential'] + func = self._hooks["get_temporary_credential"] if callable(func): return func(request.token) - raise RuntimeError( - '"get_temporary_credential" hook is required.' - ) + raise RuntimeError('"get_temporary_credential" hook is required.') def delete_temporary_credential(self, request): - func = self._hooks['delete_temporary_credential'] + func = self._hooks["delete_temporary_credential"] if callable(func): return func(request.token) - raise RuntimeError( - '"delete_temporary_credential" hook is required.' - ) + raise RuntimeError('"delete_temporary_credential" hook is required.') def create_authorization_verifier(self, request): - func = self._hooks['create_authorization_verifier'] + func = self._hooks["create_authorization_verifier"] if callable(func): verifier = generate_token(36) func(request.credential, request.user, verifier) return verifier - raise RuntimeError( - '"create_authorization_verifier" hook is required.' - ) + raise RuntimeError('"create_authorization_verifier" hook is required.') def create_token_credential(self, request): - func = self._hooks['create_token_credential'] + func = self._hooks["create_token_credential"] if callable(func): temporary_credential = request.credential token = self.token_generator() return func(token, temporary_credential) - raise RuntimeError( - '"create_token_credential" hook is required.' - ) - - def create_temporary_credentials_response(self, request=None): - return super(AuthorizationServer, self)\ - .create_temporary_credentials_response(request) + raise RuntimeError('"create_token_credential" hook is required.') def check_authorization_request(self): req = self.create_oauth1_request(None) @@ -163,18 +150,19 @@ def check_authorization_request(self): return req def create_authorization_response(self, request=None, grant_user=None): - return super(AuthorizationServer, self)\ - .create_authorization_response(request, grant_user) + return super().create_authorization_response(request, grant_user) def create_token_response(self, request=None): - return super(AuthorizationServer, self).create_token_response(request) + return super().create_token_response(request) def create_oauth1_request(self, request): - return create_oauth_request(request, OAuth1Request) + if request is None: + request = flask_req + if request.method in ("POST", "PUT"): + body = request.form.to_dict(flat=True) + else: + body = None + return OAuth1Request(request.method, request.url, body, request.headers) def handle_response(self, status_code, payload, headers): - return Response( - url_encode(payload), - status=status_code, - headers=headers - ) + return Response(url_encode(payload), status=status_code, headers=headers) diff --git a/authlib/integrations/flask_oauth1/cache.py b/authlib/integrations/flask_oauth1/cache.py index c22211baf..63f2951f3 100644 --- a/authlib/integrations/flask_oauth1/cache.py +++ b/authlib/integrations/flask_oauth1/cache.py @@ -2,7 +2,8 @@ def register_temporary_credential_hooks( - authorization_server, cache, key_prefix='temporary_credential:'): + authorization_server, cache, key_prefix="temporary_credential:" +): """Register temporary credential related hooks to authorization server. :param authorization_server: AuthorizationServer instance @@ -11,10 +12,10 @@ def register_temporary_credential_hooks( """ def create_temporary_credential(token, client_id, redirect_uri): - key = key_prefix + token['oauth_token'] - token['client_id'] = client_id + key = key_prefix + token["oauth_token"] + token["client_id"] = client_id if redirect_uri: - token['oauth_callback'] = redirect_uri + token["oauth_callback"] = redirect_uri cache.set(key, token, timeout=86400) # cache for one day return TemporaryCredential(token) @@ -34,22 +35,26 @@ def delete_temporary_credential(oauth_token): def create_authorization_verifier(credential, grant_user, verifier): key = key_prefix + credential.get_oauth_token() - credential['oauth_verifier'] = verifier - credential['user_id'] = grant_user.get_user_id() + credential["oauth_verifier"] = verifier + credential["user_id"] = grant_user.get_user_id() cache.set(key, credential, timeout=86400) return credential authorization_server.register_hook( - 'create_temporary_credential', create_temporary_credential) + "create_temporary_credential", create_temporary_credential + ) authorization_server.register_hook( - 'get_temporary_credential', get_temporary_credential) + "get_temporary_credential", get_temporary_credential + ) authorization_server.register_hook( - 'delete_temporary_credential', delete_temporary_credential) + "delete_temporary_credential", delete_temporary_credential + ) authorization_server.register_hook( - 'create_authorization_verifier', create_authorization_verifier) + "create_authorization_verifier", create_authorization_verifier + ) -def create_exists_nonce_func(cache, key_prefix='nonce:', expires=86400): +def create_exists_nonce_func(cache, key_prefix="nonce:", expires=86400): """Create an ``exists_nonce`` function that can be used in hooks and resource protector. @@ -57,18 +62,21 @@ def create_exists_nonce_func(cache, key_prefix='nonce:', expires=86400): :param key_prefix: key prefix for temporary credential :param expires: Expire time for nonce """ + def exists_nonce(nonce, timestamp, client_id, oauth_token): - key = '{}{}-{}-{}'.format(key_prefix, nonce, timestamp, client_id) + key = f"{key_prefix}{nonce}-{timestamp}-{client_id}" if oauth_token: - key = '{}-{}'.format(key, oauth_token) + key = f"{key}-{oauth_token}" rv = cache.has(key) cache.set(key, 1, timeout=expires) return rv + return exists_nonce def register_nonce_hooks( - authorization_server, cache, key_prefix='nonce:', expires=86400): + authorization_server, cache, key_prefix="nonce:", expires=86400 +): """Register nonce related hooks to authorization server. :param authorization_server: AuthorizationServer instance @@ -77,4 +85,4 @@ def register_nonce_hooks( :param expires: Expire time for nonce """ exists_nonce = create_exists_nonce_func(cache, key_prefix, expires) - authorization_server.register_hook('exists_nonce', exists_nonce) + authorization_server.register_hook("exists_nonce", exists_nonce) diff --git a/authlib/integrations/flask_oauth1/resource_protector.py b/authlib/integrations/flask_oauth1/resource_protector.py index 9f3361e15..10bd56c5e 100644 --- a/authlib/integrations/flask_oauth1/resource_protector.py +++ b/authlib/integrations/flask_oauth1/resource_protector.py @@ -1,8 +1,11 @@ import functools -from flask import json, Response + +from flask import Response +from flask import g +from flask import json from flask import request as _req -from flask import _app_ctx_stack from werkzeug.local import LocalProxy + from authlib.consts import default_json_headers from authlib.oauth1 import ResourceProtector as _ResourceProtector from authlib.oauth1.errors import OAuth1Error @@ -10,35 +13,43 @@ class ResourceProtector(_ResourceProtector): """A protecting method for resource servers. Initialize a resource - protector with the query_token method:: + protector with the these method: + + 1. query_client + 2. query_token, + 3. exists_nonce + + Usually, a ``query_client`` method would look like (if using SQLAlchemy):: + + def query_client(client_id): + return Client.query.filter_by(client_id=client_id).first() + + A ``query_token`` method accept two parameters, ``client_id`` and ``oauth_token``:: + + def query_token(client_id, oauth_token): + return Token.query.filter_by( + client_id=client_id, oauth_token=oauth_token + ).first() + + And for ``exists_nonce``, if using cache, we have a built-in hook to create this method:: - from authlib.integrations.flask_oauth1 import ResourceProtector, current_credential from authlib.integrations.flask_oauth1 import create_exists_nonce_func - from authlib.integrations.sqla_oauth1 import ( - create_query_client_func, - create_query_token_func, - ) - from your_project.models import Token, User, cache - # you need to define a ``cache`` instance yourself + exists_nonce = create_exists_nonce_func(cache) - require_oauth= ResourceProtector( - app, - query_client=create_query_client_func(db.session, OAuth1Client), - query_token=create_query_token_func(db.session, OAuth1Token), - exists_nonce=create_exists_nonce_func(cache) - ) - # or initialize it lazily - require_oauth = ResourceProtector() - require_oauth.init_app( + Then initialize the resource protector with those methods:: + + require_oauth = ResourceProtector( app, - query_client=create_query_client_func(db.session, OAuth1Client), - query_token=create_query_token_func(db.session, OAuth1Token), - exists_nonce=create_exists_nonce_func(cache) + query_client=query_client, + query_token=query_token, + exists_nonce=exists_nonce, ) """ - def __init__(self, app=None, query_client=None, - query_token=None, exists_nonce=None): + + def __init__( + self, app=None, query_client=None, query_token=None, exists_nonce=None + ): self.query_client = query_client self.query_token = query_token self._exists_nonce = exists_nonce @@ -47,8 +58,7 @@ def __init__(self, app=None, query_client=None, if app: self.init_app(app) - def init_app(self, app, query_client=None, query_token=None, - exists_nonce=None): + def init_app(self, app, query_client=None, query_token=None, exists_nonce=None): if query_client is not None: self.query_client = query_client if query_token is not None: @@ -56,7 +66,7 @@ def init_app(self, app, query_client=None, query_token=None, if exists_nonce is not None: self._exists_nonce = exists_nonce - methods = app.config.get('OAUTH1_SUPPORTED_SIGNATURE_METHODS') + methods = app.config.get("OAUTH1_SUPPORTED_SIGNATURE_METHODS") if methods and isinstance(methods, (list, tuple)): self.SUPPORTED_SIGNATURE_METHODS = methods @@ -79,17 +89,13 @@ def exists_nonce(self, nonce, request): def acquire_credential(self): req = self.validate_request( - _req.method, - _req.url, - _req.form.to_dict(flat=True), - _req.headers + _req.method, _req.url, _req.form.to_dict(flat=True), _req.headers ) - ctx = _app_ctx_stack.top - ctx.authlib_server_oauth1_credential = req.credential + g.authlib_server_oauth1_credential = req.credential return req.credential def __call__(self, scope=None): - def wrapper(f): + def decorator(f): @functools.wraps(f) def decorated(*args, **kwargs): try: @@ -102,13 +108,16 @@ def decorated(*args, **kwargs): headers=default_json_headers, ) return f(*args, **kwargs) + return decorated - return wrapper + + if callable(scope): + return decorator(scope) + return decorator def _get_current_credential(): - ctx = _app_ctx_stack.top - return getattr(ctx, 'authlib_server_oauth1_credential', None) + return g.get("authlib_server_oauth1_credential") current_credential = LocalProxy(_get_current_credential) diff --git a/authlib/integrations/flask_oauth2/__init__.py b/authlib/integrations/flask_oauth2/__init__.py index 170a7190a..0ae826570 100644 --- a/authlib/integrations/flask_oauth2/__init__.py +++ b/authlib/integrations/flask_oauth2/__init__.py @@ -1,12 +1,8 @@ # flake8: noqa from .authorization_server import AuthorizationServer -from .resource_protector import ( - ResourceProtector, - current_token, -) -from .signals import ( - client_authenticated, - token_authenticated, - token_revoked, -) +from .resource_protector import ResourceProtector +from .resource_protector import current_token +from .signals import client_authenticated +from .signals import token_authenticated +from .signals import token_revoked diff --git a/authlib/integrations/flask_oauth2/authorization_server.py b/authlib/integrations/flask_oauth2/authorization_server.py index 7eb411c6c..8944c318c 100644 --- a/authlib/integrations/flask_oauth2/authorization_server.py +++ b/authlib/integrations/flask_oauth2/authorization_server.py @@ -1,17 +1,16 @@ +from flask import Response +from flask import json +from flask import request as flask_req from werkzeug.utils import import_string -from flask import Response, json -from authlib.deprecate import deprecate -from authlib.oauth2 import ( - OAuth2Request, - HttpRequest, - AuthorizationServer as _AuthorizationServer, -) -from authlib.oauth2.rfc6750 import BearerToken -from authlib.oauth2.rfc8414 import AuthorizationServerMetadata + from authlib.common.security import generate_token -from authlib.common.encoding import to_unicode -from .signals import client_authenticated, token_revoked -from ..flask_helpers import create_oauth_request +from authlib.oauth2 import AuthorizationServer as _AuthorizationServer +from authlib.oauth2.rfc6750 import BearerTokenGenerator + +from .requests import FlaskJsonRequest +from .requests import FlaskOAuth2Request +from .signals import client_authenticated +from .signals import token_revoked class AuthorizationServer(_AuthorizationServer): @@ -22,96 +21,63 @@ class AuthorizationServer(_AuthorizationServer): def query_client(client_id): return Client.query.filter_by(client_id=client_id).first() + def save_token(token, request): if request.user: - user_id = request.user.get_user_id() + user_id = request.user.id else: user_id = None client = request.client - tok = Token( - client_id=client.client_id, - user_id=user.get_user_id(), - **token - ) + tok = Token(client_id=client.client_id, user_id=user.id, **token) db.session.add(tok) db.session.commit() + server = AuthorizationServer(app, query_client, save_token) # or initialize lazily server = AuthorizationServer() server.init_app(app, query_client, save_token) """ - metadata_class = AuthorizationServerMetadata def __init__(self, app=None, query_client=None, save_token=None): - super(AuthorizationServer, self).__init__( - query_client=query_client, - save_token=save_token, - ) - self.config = {} + super().__init__() + self._query_client = query_client + self._save_token = save_token + self._error_uris = None if app is not None: self.init_app(app) def init_app(self, app, query_client=None, save_token=None): """Initialize later with Flask app instance.""" if query_client is not None: - self.query_client = query_client + self._query_client = query_client if save_token is not None: - self.save_token = save_token - - self.generate_token = self.create_bearer_token_generator(app.config) - - metadata_file = app.config.get('OAUTH2_METADATA_FILE') - if metadata_file: - with open(metadata_file) as f: - metadata = self.metadata_class(json.load(f)) - metadata.validate() - self.metadata = metadata - - self.config.setdefault('error_uris', app.config.get('OAUTH2_ERROR_URIS')) - if app.config.get('OAUTH2_JWT_ENABLED'): - deprecate('Define "get_jwt_config" in OpenID Connect grants', '1.0') - self.init_jwt_config(app.config) - - def init_jwt_config(self, config): - """Initialize JWT related configuration.""" - jwt_iss = config.get('OAUTH2_JWT_ISS') - if not jwt_iss: - raise RuntimeError('Missing "OAUTH2_JWT_ISS" configuration.') - - jwt_key_path = config.get('OAUTH2_JWT_KEY_PATH') - if jwt_key_path: - with open(jwt_key_path, 'r') as f: - if jwt_key_path.endswith('.json'): - jwt_key = json.load(f) - else: - jwt_key = to_unicode(f.read()) - else: - jwt_key = config.get('OAUTH2_JWT_KEY') - - if not jwt_key: - raise RuntimeError('Missing "OAUTH2_JWT_KEY" configuration.') - - jwt_alg = config.get('OAUTH2_JWT_ALG') - if not jwt_alg: - raise RuntimeError('Missing "OAUTH2_JWT_ALG" configuration.') - - jwt_exp = config.get('OAUTH2_JWT_EXP', 3600) - self.config.setdefault('jwt_iss', jwt_iss) - self.config.setdefault('jwt_key', jwt_key) - self.config.setdefault('jwt_alg', jwt_alg) - self.config.setdefault('jwt_exp', jwt_exp) - - def get_error_uris(self, request): - error_uris = self.config.get('error_uris') - if error_uris: - return dict(error_uris) + self._save_token = save_token + self.load_config(app.config) + + def load_config(self, config): + self.register_token_generator( + "default", self.create_bearer_token_generator(config) + ) + self.scopes_supported = config.get("OAUTH2_SCOPES_SUPPORTED") + self._error_uris = config.get("OAUTH2_ERROR_URIS") + + def query_client(self, client_id): + return self._query_client(client_id) + + def save_token(self, token, request): + return self._save_token(token, request) + + def get_error_uri(self, request, error): + if self._error_uris: + uris = dict(self._error_uris) + return uris.get(error.error) def create_oauth2_request(self, request): - return create_oauth_request(request, OAuth2Request) + return FlaskOAuth2Request(flask_req) def create_json_request(self, request): - return create_oauth_request(request, HttpRequest, True) + return FlaskJsonRequest(flask_req) def handle_response(self, status_code, payload, headers): if isinstance(payload, dict): @@ -119,83 +85,68 @@ def handle_response(self, status_code, payload, headers): return Response(payload, status=status_code, headers=headers) def send_signal(self, name, *args, **kwargs): - if name == 'after_authenticate_client': + if name == "after_authenticate_client": client_authenticated.send(self, *args, **kwargs) - elif name == 'after_revoke_token': + elif name == "after_revoke_token": token_revoked.send(self, *args, **kwargs) - def create_token_expires_in_generator(self, config): - """Create a generator function for generating ``expires_in`` value. - Developers can re-implement this method with a subclass if other means - required. The default expires_in value is defined by ``grant_type``, - different ``grant_type`` has different value. It can be configured - with:: - - OAUTH2_TOKEN_EXPIRES_IN = { - 'authorization_code': 864000, - 'urn:ietf:params:oauth:grant-type:jwt-bearer': 3600, - } - """ - expires_conf = config.get('OAUTH2_TOKEN_EXPIRES_IN') - return create_token_expires_in_generator(expires_conf) - def create_bearer_token_generator(self, config): """Create a generator function for generating ``token`` value. This method will create a Bearer Token generator with - :class:`authlib.oauth2.rfc6750.BearerToken`. By default, it will not - generate ``refresh_token``, which can be turn on by configuration - ``OAUTH2_REFRESH_TOKEN_GENERATOR=True``. + :class:`authlib.oauth2.rfc6750.BearerToken`. + + Configurable settings: + + 1. OAUTH2_ACCESS_TOKEN_GENERATOR: Boolean or import string, default is True. + 2. OAUTH2_REFRESH_TOKEN_GENERATOR: Boolean or import string, default is False. + 3. OAUTH2_TOKEN_EXPIRES_IN: Dict or import string, default is None. + + By default, it will not generate ``refresh_token``, which can be turn on by + configure ``OAUTH2_REFRESH_TOKEN_GENERATOR``. + + Here are some examples of the token generator:: + + OAUTH2_ACCESS_TOKEN_GENERATOR = "your_project.generators.gen_token" + + # and in module `your_project.generators`, you can define: + + + def gen_token(client, grant_type, user, scope): + # generate token according to these parameters + token = create_random_token() + return f"{client.id}-{user.id}-{token}" + + Here is an example of ``OAUTH2_TOKEN_EXPIRES_IN``:: + + OAUTH2_TOKEN_EXPIRES_IN = { + "authorization_code": 864000, + "urn:ietf:params:oauth:grant-type:jwt-bearer": 3600, + } """ - conf = config.get('OAUTH2_ACCESS_TOKEN_GENERATOR', True) + conf = config.get("OAUTH2_ACCESS_TOKEN_GENERATOR", True) access_token_generator = create_token_generator(conf, 42) - conf = config.get('OAUTH2_REFRESH_TOKEN_GENERATOR', False) + conf = config.get("OAUTH2_REFRESH_TOKEN_GENERATOR", False) refresh_token_generator = create_token_generator(conf, 48) - expires_generator = self.create_token_expires_in_generator(config) - return BearerToken( - access_token_generator, - refresh_token_generator, - expires_generator + expires_conf = config.get("OAUTH2_TOKEN_EXPIRES_IN") + expires_generator = create_token_expires_in_generator(expires_conf) + return BearerTokenGenerator( + access_token_generator, refresh_token_generator, expires_generator ) - def validate_consent_request(self, request=None, end_user=None): - """Validate current HTTP request for authorization page. This page - is designed for resource owner to grant or deny the authorization:: - - @app.route('/authorize', methods=['GET']) - def authorize(): - try: - grant = server.validate_consent_request(end_user=current_user) - return render_template( - 'authorize.html', - grant=grant, - user=current_user - ) - except OAuth2Error as error: - return render_template( - 'error.html', - error=error - ) - """ - req = self.create_oauth2_request(request) - req.user = end_user - - grant = self.get_authorization_grant(req) - grant.validate_consent_request() - if not hasattr(grant, 'prompt'): - grant.prompt = None - return grant - def create_token_expires_in_generator(expires_in_conf=None): + if isinstance(expires_in_conf, str): + return import_string(expires_in_conf) + data = {} - data.update(BearerToken.GRANT_TYPES_EXPIRES_IN) - if expires_in_conf: + data.update(BearerTokenGenerator.GRANT_TYPES_EXPIRES_IN) + if isinstance(expires_in_conf, dict): data.update(expires_in_conf) def expires_in(client, grant_type): - return data.get(grant_type, BearerToken.DEFAULT_EXPIRES_IN) + return data.get(grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN) return expires_in @@ -207,6 +158,8 @@ def create_token_generator(token_generator_conf, length=42): if isinstance(token_generator_conf, str): return import_string(token_generator_conf) elif token_generator_conf is True: + def token_generator(*args, **kwargs): return generate_token(length) + return token_generator diff --git a/authlib/integrations/flask_oauth2/errors.py b/authlib/integrations/flask_oauth2/errors.py index e9c9fdea4..5f499d119 100644 --- a/authlib/integrations/flask_oauth2/errors.py +++ b/authlib/integrations/flask_oauth2/errors.py @@ -1,19 +1,39 @@ +import importlib.metadata + from werkzeug.exceptions import HTTPException +_version = importlib.metadata.version("werkzeug").split(".")[0] + +if _version in ("0", "1"): + + class _HTTPException(HTTPException): + def __init__(self, code, body, headers, response=None): + super().__init__(None, response) + self.code = code + + self.body = body + self.headers = headers + + def get_body(self, environ=None): + return self.body + + def get_headers(self, environ=None): + return self.headers +else: -class _HTTPException(HTTPException): - def __init__(self, code, body, headers, response=None): - super(_HTTPException, self).__init__(None, response) - self.code = code + class _HTTPException(HTTPException): + def __init__(self, code, body, headers, response=None): + super().__init__(None, response) + self.code = code - self.body = body - self.headers = headers + self.body = body + self.headers = headers - def get_body(self, environ=None): - return self.body + def get_body(self, environ=None, scope=None): + return self.body - def get_headers(self, environ=None): - return self.headers + def get_headers(self, environ=None, scope=None): + return self.headers def raise_http_exception(status, body, headers): diff --git a/authlib/integrations/flask_oauth2/requests.py b/authlib/integrations/flask_oauth2/requests.py new file mode 100644 index 000000000..c09b41133 --- /dev/null +++ b/authlib/integrations/flask_oauth2/requests.py @@ -0,0 +1,57 @@ +from collections import defaultdict +from functools import cached_property + +from flask.wrappers import Request + +from authlib.oauth2.rfc6749 import JsonPayload +from authlib.oauth2.rfc6749 import JsonRequest +from authlib.oauth2.rfc6749 import OAuth2Payload +from authlib.oauth2.rfc6749 import OAuth2Request + + +class FlaskOAuth2Payload(OAuth2Payload): + def __init__(self, request: Request): + self._request = request + + @property + def data(self): + return self._request.values + + @cached_property + def datalist(self): + values = defaultdict(list) + for k in self.data: + values[k].extend(self.data.getlist(k)) + return values + + +class FlaskOAuth2Request(OAuth2Request): + def __init__(self, request: Request): + super().__init__( + method=request.method, uri=request.url, headers=request.headers + ) + self._request = request + self.payload = FlaskOAuth2Payload(request) + + @property + def args(self): + return self._request.args + + @property + def form(self): + return self._request.form + + +class FlaskJsonPayload(JsonPayload): + def __init__(self, request: Request): + self._request = request + + @property + def data(self): + return self._request.get_json() + + +class FlaskJsonRequest(JsonRequest): + def __init__(self, request: Request): + super().__init__(request.method, request.url, request.headers) + self.payload = FlaskJsonPayload(request) diff --git a/authlib/integrations/flask_oauth2/resource_protector.py b/authlib/integrations/flask_oauth2/resource_protector.py index 41535f35f..5f6c5e591 100644 --- a/authlib/integrations/flask_oauth2/resource_protector.py +++ b/authlib/integrations/flask_oauth2/resource_protector.py @@ -1,19 +1,18 @@ import functools from contextlib import contextmanager + +from flask import g from flask import json from flask import request as _req -from flask import _app_ctx_stack from werkzeug.local import LocalProxy -from authlib.oauth2 import ( - OAuth2Error, - ResourceProtector as _ResourceProtector -) -from authlib.oauth2.rfc6749 import ( - MissingAuthorizationError, - HttpRequest, -) -from .signals import token_authenticated + +from authlib.oauth2 import OAuth2Error +from authlib.oauth2 import ResourceProtector as _ResourceProtector +from authlib.oauth2.rfc6749 import MissingAuthorizationError + from .errors import raise_http_exception +from .requests import FlaskJsonRequest +from .signals import token_authenticated class ResourceProtector(_ResourceProtector): @@ -28,27 +27,25 @@ class ResourceProtector(_ResourceProtector): from authlib.oauth2.rfc6750 import BearerTokenValidator from project.models import Token + class MyBearerTokenValidator(BearerTokenValidator): def authenticate_token(self, token_string): return Token.query.filter_by(access_token=token_string).first() - def request_invalid(self, request): - return False - - def token_revoked(self, token): - return False require_oauth.register_token_validator(MyBearerTokenValidator()) # protect resource with require_oauth - @app.route('/user') - @require_oauth('profile') + + @app.route("/user") + @require_oauth(["profile"]) def user_profile(): - user = User.query.get(current_token.user_id) + user = User.get(current_token.user_id) return jsonify(user.to_dict()) """ + def raise_error_response(self, error): """Raise HTTPException for OAuth2Error. Developers can re-implement this method to customize the error response. @@ -61,49 +58,48 @@ def raise_error_response(self, error): headers = error.get_headers() raise_http_exception(status, body, headers) - def acquire_token(self, scope=None, operator='AND'): + def acquire_token(self, scopes=None, **kwargs): """A method to acquire current valid token with the given scope. - :param scope: string or list of scope values - :param operator: value of "AND" or "OR" + :param scopes: a list of scope values :return: token object """ - request = HttpRequest( - _req.method, - _req.full_path, - _req.data, - _req.headers - ) - if not callable(operator): - operator = operator.upper() - token = self.validate_request(scope, request, operator) + request = FlaskJsonRequest(_req) + # backward compatibility + kwargs["scopes"] = scopes + for claim in kwargs: + if isinstance(kwargs[claim], str): + kwargs[claim] = [kwargs[claim]] + token = self.validate_request(request=request, **kwargs) token_authenticated.send(self, token=token) - ctx = _app_ctx_stack.top - ctx.authlib_server_oauth2_token = token + g.authlib_server_oauth2_token = token return token @contextmanager - def acquire(self, scope=None, operator='AND'): + def acquire(self, scopes=None): """The with statement of ``require_oauth``. Instead of using a decorator, you can use a with statement instead:: - @app.route('/api/user') + @app.route("/api/user") def user_api(): - with require_oauth.acquire('profile') as token: - user = User.query.get(token.user_id) + with require_oauth.acquire("profile") as token: + user = User.get(token.user_id) return jsonify(user.to_dict()) """ try: - yield self.acquire_token(scope, operator) + yield self.acquire_token(scopes) except OAuth2Error as error: self.raise_error_response(error) - def __call__(self, scope=None, operator='AND', optional=False): - def wrapper(f): + def __call__(self, scopes=None, optional=False, **kwargs): + claims = kwargs + claims["scopes"] = scopes if not callable(scopes) else None + + def decorator(f): @functools.wraps(f) def decorated(*args, **kwargs): try: - self.acquire_token(scope, operator) + self.acquire_token(**claims) except MissingAuthorizationError as error: if optional: return f(*args, **kwargs) @@ -111,13 +107,16 @@ def decorated(*args, **kwargs): except OAuth2Error as error: self.raise_error_response(error) return f(*args, **kwargs) + return decorated - return wrapper + + if callable(scopes): + return decorator(scopes) + return decorator def _get_current_token(): - ctx = _app_ctx_stack.top - return getattr(ctx, 'authlib_server_oauth2_token', None) + return g.get("authlib_server_oauth2_token") current_token = LocalProxy(_get_current_token) diff --git a/authlib/integrations/flask_oauth2/signals.py b/authlib/integrations/flask_oauth2/signals.py index c61e0119d..f29ba1158 100644 --- a/authlib/integrations/flask_oauth2/signals.py +++ b/authlib/integrations/flask_oauth2/signals.py @@ -3,10 +3,10 @@ _signal = Namespace() #: signal when client is authenticated -client_authenticated = _signal.signal('client_authenticated') +client_authenticated = _signal.signal("client_authenticated") #: signal when token is revoked -token_revoked = _signal.signal('token_revoked') +token_revoked = _signal.signal("token_revoked") #: signal when token is authenticated -token_authenticated = _signal.signal('token_authenticated') +token_authenticated = _signal.signal("token_authenticated") diff --git a/authlib/integrations/httpx_client/__init__.py b/authlib/integrations/httpx_client/__init__.py index 6b4b9d677..006494121 100644 --- a/authlib/integrations/httpx_client/__init__.py +++ b/authlib/integrations/httpx_client/__init__.py @@ -1,25 +1,36 @@ -from authlib.oauth1 import ( - SIGNATURE_HMAC_SHA1, - SIGNATURE_RSA_SHA1, - SIGNATURE_PLAINTEXT, - SIGNATURE_TYPE_HEADER, - SIGNATURE_TYPE_QUERY, - SIGNATURE_TYPE_BODY, -) -from .oauth1_client import OAuth1Auth, AsyncOAuth1Client, OAuth1Client -from .oauth2_client import ( - OAuth2Auth, OAuth2Client, OAuth2ClientAuth, - AsyncOAuth2Client, -) -from .assertion_client import AssertionClient, AsyncAssertionClient -from ..base_client import OAuthError +from authlib.oauth1 import SIGNATURE_HMAC_SHA1 +from authlib.oauth1 import SIGNATURE_PLAINTEXT +from authlib.oauth1 import SIGNATURE_RSA_SHA1 +from authlib.oauth1 import SIGNATURE_TYPE_BODY +from authlib.oauth1 import SIGNATURE_TYPE_HEADER +from authlib.oauth1 import SIGNATURE_TYPE_QUERY +from ..base_client import OAuthError +from .assertion_client import AssertionClient +from .assertion_client import AsyncAssertionClient +from .oauth1_client import AsyncOAuth1Client +from .oauth1_client import OAuth1Auth +from .oauth1_client import OAuth1Client +from .oauth2_client import AsyncOAuth2Client +from .oauth2_client import OAuth2Auth +from .oauth2_client import OAuth2Client +from .oauth2_client import OAuth2ClientAuth __all__ = [ - 'OAuthError', - 'OAuth1Auth', 'AsyncOAuth1Client', - 'SIGNATURE_HMAC_SHA1', 'SIGNATURE_RSA_SHA1', 'SIGNATURE_PLAINTEXT', - 'SIGNATURE_TYPE_HEADER', 'SIGNATURE_TYPE_QUERY', 'SIGNATURE_TYPE_BODY', - 'OAuth2Auth', 'OAuth2ClientAuth', 'AsyncOAuth2Client', - 'AsyncAssertionClient', + "OAuthError", + "OAuth1Auth", + "AsyncOAuth1Client", + "OAuth1Client", + "SIGNATURE_HMAC_SHA1", + "SIGNATURE_RSA_SHA1", + "SIGNATURE_PLAINTEXT", + "SIGNATURE_TYPE_HEADER", + "SIGNATURE_TYPE_QUERY", + "SIGNATURE_TYPE_BODY", + "OAuth2Auth", + "OAuth2ClientAuth", + "OAuth2Client", + "AsyncOAuth2Client", + "AssertionClient", + "AsyncAssertionClient", ] diff --git a/authlib/integrations/httpx_client/assertion_client.py b/authlib/integrations/httpx_client/assertion_client.py index 62f81b797..9d52dad81 100644 --- a/authlib/integrations/httpx_client/assertion_client.py +++ b/authlib/integrations/httpx_client/assertion_client.py @@ -1,86 +1,124 @@ -from httpx import AsyncClient, Client -from httpx._config import UNSET +import httpx +from httpx import USE_CLIENT_DEFAULT +from httpx import Response + from authlib.oauth2.rfc7521 import AssertionClient as _AssertionClient from authlib.oauth2.rfc7523 import JWTBearerGrant -from authlib.oauth2 import OAuth2Error -from .utils import extract_client_kwargs + +from ..base_client import OAuthError from .oauth2_client import OAuth2Auth +from .utils import extract_client_kwargs -__all__ = ['AsyncAssertionClient'] +__all__ = ["AsyncAssertionClient"] -class AsyncAssertionClient(_AssertionClient, AsyncClient): +class AsyncAssertionClient(_AssertionClient, httpx.AsyncClient): token_auth_class = OAuth2Auth + oauth_error_class = OAuthError JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE ASSERTION_METHODS = { JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign, } DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE - def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None, - claims=None, token_placement='header', scope=None, **kwargs): - + def __init__( + self, + token_endpoint, + issuer, + subject, + audience=None, + grant_type=None, + claims=None, + token_placement="header", + scope=None, + **kwargs, + ): client_kwargs = extract_client_kwargs(kwargs) - AsyncClient.__init__(self, **client_kwargs) + httpx.AsyncClient.__init__(self, **client_kwargs) _AssertionClient.__init__( - self, session=None, - token_endpoint=token_endpoint, issuer=issuer, subject=subject, - audience=audience, grant_type=grant_type, claims=claims, - token_placement=token_placement, scope=scope, **kwargs + self, + session=None, + token_endpoint=token_endpoint, + issuer=issuer, + subject=subject, + audience=audience, + grant_type=grant_type, + claims=claims, + token_placement=token_placement, + scope=scope, + **kwargs, ) - async def request(self, method, url, withhold_token=False, auth=None, **kwargs): + async def request( + self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs + ) -> Response: """Send request with auto refresh token feature.""" - if not withhold_token and auth is UNSET: + if not withhold_token and auth is USE_CLIENT_DEFAULT: if not self.token or self.token.is_expired(): await self.refresh_token() auth = self.token_auth - return await super(AsyncAssertionClient, self).request( - method, url, auth=auth, **kwargs) + return await super().request(method, url, auth=auth, **kwargs) async def _refresh_token(self, data): resp = await self.request( - 'POST', self.token_endpoint, data=data, withhold_token=True) + "POST", self.token_endpoint, data=data, withhold_token=True + ) - token = resp.json() - if 'error' in token: - raise OAuth2Error( - error=token['error'], - description=token.get('error_description') - ) - self.token = token - return self.token + return self.parse_response_token(resp) -class AssertionClient(_AssertionClient, Client): +class AssertionClient(_AssertionClient, httpx.Client): token_auth_class = OAuth2Auth + oauth_error_class = OAuthError JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE ASSERTION_METHODS = { JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign, } DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE - def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None, - claims=None, token_placement='header', scope=None, **kwargs): - + def __init__( + self, + token_endpoint, + issuer, + subject, + audience=None, + grant_type=None, + claims=None, + token_placement="header", + scope=None, + **kwargs, + ): client_kwargs = extract_client_kwargs(kwargs) - Client.__init__(self, **client_kwargs) + # app keyword was dropped! + app_value = client_kwargs.pop("app", None) + if app_value is not None: + client_kwargs["transport"] = httpx.WSGITransport(app=app_value) + + httpx.Client.__init__(self, **client_kwargs) _AssertionClient.__init__( - self, session=self, - token_endpoint=token_endpoint, issuer=issuer, subject=subject, - audience=audience, grant_type=grant_type, claims=claims, - token_placement=token_placement, scope=scope, **kwargs + self, + session=self, + token_endpoint=token_endpoint, + issuer=issuer, + subject=subject, + audience=audience, + grant_type=grant_type, + claims=claims, + token_placement=token_placement, + scope=scope, + **kwargs, ) - def request(self, method, url, withhold_token=False, auth=None, **kwargs): + def request( + self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs + ): """Send request with auto refresh token feature.""" - if not withhold_token and auth is UNSET: + if not withhold_token and auth is USE_CLIENT_DEFAULT: if not self.token or self.token.is_expired(): self.refresh_token() auth = self.token_auth - return super(AssertionClient, self).request( - method, url, auth=auth, **kwargs) + return super().request(method, url, auth=auth, **kwargs) diff --git a/authlib/integrations/httpx_client/oauth1_client.py b/authlib/integrations/httpx_client/oauth1_client.py index 6d755bb1e..a47570707 100644 --- a/authlib/integrations/httpx_client/oauth1_client.py +++ b/authlib/integrations/httpx_client/oauth1_client.py @@ -1,47 +1,71 @@ import typing -from httpx import AsyncClient, Auth, Client, Request, Response -from authlib.oauth1 import ( - SIGNATURE_HMAC_SHA1, - SIGNATURE_TYPE_HEADER, -) + +import httpx +from httpx import Auth +from httpx import Request +from httpx import Response + from authlib.common.encoding import to_unicode +from authlib.oauth1 import SIGNATURE_HMAC_SHA1 +from authlib.oauth1 import SIGNATURE_TYPE_HEADER from authlib.oauth1 import ClientAuth from authlib.oauth1.client import OAuth1Client as _OAuth1Client -from .utils import extract_client_kwargs + from ..base_client import OAuthError +from .utils import build_request +from .utils import extract_client_kwargs class OAuth1Auth(Auth, ClientAuth): - """Signs the httpx request using OAuth 1 (RFC5849)""" + """Signs the httpx request using OAuth 1 (RFC5849).""" + requires_request_body = True def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: url, headers, body = self.prepare( - request.method, str(request.url), request.headers, request.content) - headers['Content-Length'] = str(len(body)) - yield Request(method=request.method, url=url, headers=headers, data=body) + request.method, str(request.url), request.headers, request.content + ) + headers["Content-Length"] = str(len(body)) + yield build_request( + url=url, headers=headers, body=body, initial_request=request + ) -class AsyncOAuth1Client(_OAuth1Client, AsyncClient): +class AsyncOAuth1Client(_OAuth1Client, httpx.AsyncClient): auth_class = OAuth1Auth - def __init__(self, client_id, client_secret=None, - token=None, token_secret=None, - redirect_uri=None, rsa_key=None, verifier=None, - signature_method=SIGNATURE_HMAC_SHA1, - signature_type=SIGNATURE_TYPE_HEADER, - force_include_body=False, **kwargs): - + def __init__( + self, + client_id, + client_secret=None, + token=None, + token_secret=None, + redirect_uri=None, + rsa_key=None, + verifier=None, + signature_method=SIGNATURE_HMAC_SHA1, + signature_type=SIGNATURE_TYPE_HEADER, + force_include_body=False, + **kwargs, + ): _client_kwargs = extract_client_kwargs(kwargs) - AsyncClient.__init__(self, **_client_kwargs) + httpx.AsyncClient.__init__(self, **_client_kwargs) _OAuth1Client.__init__( - self, None, - client_id=client_id, client_secret=client_secret, - token=token, token_secret=token_secret, - redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier, - signature_method=signature_method, signature_type=signature_type, - force_include_body=force_include_body, **kwargs) + self, + None, + client_id=client_id, + client_secret=client_secret, + token=token, + token_secret=token_secret, + redirect_uri=redirect_uri, + rsa_key=rsa_key, + verifier=verifier, + signature_method=signature_method, + signature_type=signature_type, + force_include_body=force_include_body, + **kwargs, + ) async def fetch_access_token(self, url, verifier=None, **kwargs): """Method for fetching an access token from the token endpoint. @@ -58,7 +82,7 @@ async def fetch_access_token(self, url, verifier=None, **kwargs): if verifier: self.auth.verifier = verifier if not self.auth.verifier: - self.handle_error('missing_verifier', 'Missing "verifier" value') + self.handle_error("missing_verifier", 'Missing "verifier" value') token = await self._fetch_token(url, **kwargs) self.auth.verifier = None return token @@ -74,26 +98,47 @@ async def _fetch_token(self, url, **kwargs): def handle_error(error_type, error_description): raise OAuthError(error_type, error_description) -class OAuth1Client(_OAuth1Client, Client): - auth_class = OAuth1Auth - def __init__(self, client_id, client_secret=None, - token=None, token_secret=None, - redirect_uri=None, rsa_key=None, verifier=None, - signature_method=SIGNATURE_HMAC_SHA1, - signature_type=SIGNATURE_TYPE_HEADER, - force_include_body=False, **kwargs): +class OAuth1Client(_OAuth1Client, httpx.Client): + auth_class = OAuth1Auth + def __init__( + self, + client_id, + client_secret=None, + token=None, + token_secret=None, + redirect_uri=None, + rsa_key=None, + verifier=None, + signature_method=SIGNATURE_HMAC_SHA1, + signature_type=SIGNATURE_TYPE_HEADER, + force_include_body=False, + **kwargs, + ): _client_kwargs = extract_client_kwargs(kwargs) - Client.__init__(self, **_client_kwargs) + # app keyword was dropped! + app_value = _client_kwargs.pop("app", None) + if app_value is not None: + _client_kwargs["transport"] = httpx.WSGITransport(app=app_value) + + httpx.Client.__init__(self, **_client_kwargs) _OAuth1Client.__init__( - self, self, - client_id=client_id, client_secret=client_secret, - token=token, token_secret=token_secret, - redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier, - signature_method=signature_method, signature_type=signature_type, - force_include_body=force_include_body, **kwargs) + self, + self, + client_id=client_id, + client_secret=client_secret, + token=token, + token_secret=token_secret, + redirect_uri=redirect_uri, + rsa_key=rsa_key, + verifier=verifier, + signature_method=signature_method, + signature_type=signature_type, + force_include_body=force_include_body, + **kwargs, + ) @staticmethod def handle_error(error_type, error_description): diff --git a/authlib/integrations/httpx_client/oauth2_client.py b/authlib/integrations/httpx_client/oauth2_client.py index 387560b98..a157b7eb7 100644 --- a/authlib/integrations/httpx_client/oauth2_client.py +++ b/authlib/integrations/httpx_client/oauth2_client.py @@ -1,36 +1,50 @@ -import asyncio import typing -from httpx import AsyncClient, Auth, Client, Request, Response -from httpx._config import UNSET +from contextlib import asynccontextmanager + +import httpx +from anyio import Lock # Import after httpx so import errors refer to httpx +from httpx import USE_CLIENT_DEFAULT +from httpx import Auth +from httpx import Request +from httpx import Response + from authlib.common.urls import url_decode +from authlib.oauth2.auth import ClientAuth +from authlib.oauth2.auth import TokenAuth from authlib.oauth2.client import OAuth2Client as _OAuth2Client -from authlib.oauth2.auth import ClientAuth, TokenAuth + +from ..base_client import InvalidTokenError +from ..base_client import MissingTokenError +from ..base_client import OAuthError +from ..base_client import UnsupportedTokenTypeError from .utils import HTTPX_CLIENT_KWARGS -from ..base_client import ( - OAuthError, - InvalidTokenError, - MissingTokenError, - UnsupportedTokenTypeError, -) +from .utils import build_request __all__ = [ - 'OAuth2Auth', 'OAuth2ClientAuth', - 'AsyncOAuth2Client', + "OAuth2Auth", + "OAuth2ClientAuth", + "AsyncOAuth2Client", + "OAuth2Client", ] class OAuth2Auth(Auth, TokenAuth): """Sign requests for OAuth 2.0, currently only bearer token is supported.""" + requires_request_body = True def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: try: url, headers, body = self.prepare( - str(request.url), request.headers, request.content) - yield Request(method=request.method, url=url, headers=headers, data=body) + str(request.url), request.headers, request.content + ) + headers["Content-Length"] = str(len(body)) + yield build_request( + url=url, headers=headers, body=body, initial_request=request + ) except KeyError as error: - description = 'Unsupported token_type: {}'.format(str(error)) - raise UnsupportedTokenTypeError(description=description) + description = f"Unsupported token_type: {str(error)}" + raise UnsupportedTokenTypeError(description=description) from error class OAuth2ClientAuth(Auth, ClientAuth): @@ -38,173 +52,234 @@ class OAuth2ClientAuth(Auth, ClientAuth): def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: url, headers, body = self.prepare( - request.method, str(request.url), request.headers, request.content) - yield Request(method=request.method, url=url, headers=headers, data=body) + request.method, str(request.url), request.headers, request.content + ) + headers["Content-Length"] = str(len(body)) + yield build_request( + url=url, headers=headers, body=body, initial_request=request + ) -class AsyncOAuth2Client(_OAuth2Client, AsyncClient): +class AsyncOAuth2Client(_OAuth2Client, httpx.AsyncClient): SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS client_auth_class = OAuth2ClientAuth token_auth_class = OAuth2Auth - - def __init__(self, client_id=None, client_secret=None, - token_endpoint_auth_method=None, - revocation_endpoint_auth_method=None, - scope=None, redirect_uri=None, - token=None, token_placement='header', - update_token=None, **kwargs): - + oauth_error_class = OAuthError + + def __init__( + self, + client_id=None, + client_secret=None, + token_endpoint_auth_method=None, + revocation_endpoint_auth_method=None, + scope=None, + redirect_uri=None, + token=None, + token_placement="header", + update_token=None, + leeway=60, + **kwargs, + ): # extract httpx.Client kwargs client_kwargs = self._extract_session_request_params(kwargs) - AsyncClient.__init__(self, **client_kwargs) + httpx.AsyncClient.__init__(self, **client_kwargs) - # We use a "reverse" Event to synchronize coroutines to prevent + # We use a Lock to synchronize coroutines to prevent # multiple concurrent attempts to refresh the same token - self._token_refresh_event = asyncio.Event() - self._token_refresh_event.set() + self._token_refresh_lock = Lock() _OAuth2Client.__init__( - self, session=None, - client_id=client_id, client_secret=client_secret, + self, + session=None, + client_id=client_id, + client_secret=client_secret, token_endpoint_auth_method=token_endpoint_auth_method, revocation_endpoint_auth_method=revocation_endpoint_auth_method, - scope=scope, redirect_uri=redirect_uri, - token=token, token_placement=token_placement, - update_token=update_token, **kwargs + scope=scope, + redirect_uri=redirect_uri, + token=token, + token_placement=token_placement, + update_token=update_token, + leeway=leeway, + **kwargs, ) - @staticmethod - def handle_error(error_type, error_description): - raise OAuthError(error_type, error_description) + async def request( + self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs + ): + if not withhold_token and auth is USE_CLIENT_DEFAULT: + if not self.token: + raise MissingTokenError() + + await self.ensure_active_token(self.token) - async def request(self, method, url, withhold_token=False, auth=None, **kwargs): - if not withhold_token and auth is UNSET: + auth = self.token_auth + + return await super().request(method, url, auth=auth, **kwargs) + + @asynccontextmanager + async def stream( + self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs + ): + if not withhold_token and auth is USE_CLIENT_DEFAULT: if not self.token: raise MissingTokenError() - if self.token.is_expired(): - await self.ensure_active_token() + await self.ensure_active_token(self.token) auth = self.token_auth - return await super(AsyncOAuth2Client, self).request( - method, url, auth=auth, **kwargs) - - async def ensure_active_token(self): - if self._token_refresh_event.is_set(): - # Unset the event so other coroutines don't try to update the token - self._token_refresh_event.clear() - refresh_token = self.token.get('refresh_token') - url = self.metadata.get('token_endpoint') - if refresh_token and url: - await self.refresh_token(url, refresh_token=refresh_token) - elif self.metadata.get('grant_type') == 'client_credentials': - access_token = self.token['access_token'] - token = await self.fetch_token(url, grant_type='client_credentials') - if self.update_token: - await self.update_token(token, access_token=access_token) - else: - raise InvalidTokenError() - # Notify coroutines that token is refreshed - self._token_refresh_event.set() - return - await self._token_refresh_event.wait() # wait until the token is ready - - async def _fetch_token(self, url, body='', headers=None, auth=None, - method='POST', **kwargs): - if method.upper() == 'POST': + async with super().stream(method, url, auth=auth, **kwargs) as resp: + yield resp + + async def ensure_active_token(self, token): + async with self._token_refresh_lock: + if self.token.is_expired(leeway=self.leeway): + refresh_token = token.get("refresh_token") + url = self.metadata.get("token_endpoint") + if refresh_token and url: + await self.refresh_token(url, refresh_token=refresh_token) + elif self.metadata.get("grant_type") == "client_credentials": + access_token = token["access_token"] + new_token = await self.fetch_token( + url, grant_type="client_credentials" + ) + if self.update_token: + await self.update_token(new_token, access_token=access_token) + else: + raise InvalidTokenError() + + async def _fetch_token( + self, + url, + body="", + headers=None, + auth=USE_CLIENT_DEFAULT, + method="POST", + **kwargs, + ): + if method.upper() == "POST": resp = await self.post( - url, data=dict(url_decode(body)), headers=headers, - auth=auth, **kwargs) + url, data=dict(url_decode(body)), headers=headers, auth=auth, **kwargs + ) else: - if '?' in url: - url = '&'.join([url, body]) + if "?" in url: + url = "&".join([url, body]) else: - url = '?'.join([url, body]) + url = "?".join([url, body]) resp = await self.get(url, headers=headers, auth=auth, **kwargs) - for hook in self.compliance_hook['access_token_response']: + for hook in self.compliance_hook["access_token_response"]: resp = hook(resp) - return self.parse_response_token(resp.json()) - - async def _refresh_token(self, url, refresh_token=None, body='', - headers=None, auth=None, **kwargs): + return self.parse_response_token(resp) + + async def _refresh_token( + self, + url, + refresh_token=None, + body="", + headers=None, + auth=USE_CLIENT_DEFAULT, + **kwargs, + ): resp = await self.post( - url, data=dict(url_decode(body)), headers=headers, - auth=auth, **kwargs) + url, data=dict(url_decode(body)), headers=headers, auth=auth, **kwargs + ) - for hook in self.compliance_hook['refresh_token_response']: + for hook in self.compliance_hook["refresh_token_response"]: resp = hook(resp) - token = self.parse_response_token(resp.json()) - if 'refresh_token' not in token: - self.token['refresh_token'] = refresh_token + token = self.parse_response_token(resp) + if "refresh_token" not in token: + self.token["refresh_token"] = refresh_token if self.update_token: await self.update_token(self.token, refresh_token=refresh_token) return self.token - def _http_post(self, url, body=None, auth=None, headers=None, **kwargs): + def _http_post( + self, url, body=None, auth=USE_CLIENT_DEFAULT, headers=None, **kwargs + ): return self.post( - url, data=dict(url_decode(body)), - headers=headers, auth=auth, **kwargs) + url, data=dict(url_decode(body)), headers=headers, auth=auth, **kwargs + ) + -class OAuth2Client(_OAuth2Client, Client): +class OAuth2Client(_OAuth2Client, httpx.Client): SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS client_auth_class = OAuth2ClientAuth token_auth_class = OAuth2Auth - - def __init__(self, client_id=None, client_secret=None, - token_endpoint_auth_method=None, - revocation_endpoint_auth_method=None, - scope=None, redirect_uri=None, - token=None, token_placement='header', - update_token=None, **kwargs): - + oauth_error_class = OAuthError + + def __init__( + self, + client_id=None, + client_secret=None, + token_endpoint_auth_method=None, + revocation_endpoint_auth_method=None, + scope=None, + redirect_uri=None, + token=None, + token_placement="header", + update_token=None, + **kwargs, + ): # extract httpx.Client kwargs client_kwargs = self._extract_session_request_params(kwargs) - Client.__init__(self, **client_kwargs) + # app keyword was dropped! + app_value = client_kwargs.pop("app", None) + if app_value is not None: + client_kwargs["transport"] = httpx.WSGITransport(app=app_value) + + httpx.Client.__init__(self, **client_kwargs) _OAuth2Client.__init__( - self, session=self, - client_id=client_id, client_secret=client_secret, + self, + session=self, + client_id=client_id, + client_secret=client_secret, token_endpoint_auth_method=token_endpoint_auth_method, revocation_endpoint_auth_method=revocation_endpoint_auth_method, - scope=scope, redirect_uri=redirect_uri, - token=token, token_placement=token_placement, - update_token=update_token, **kwargs + scope=scope, + redirect_uri=redirect_uri, + token=token, + token_placement=token_placement, + update_token=update_token, + **kwargs, ) @staticmethod def handle_error(error_type, error_description): raise OAuthError(error_type, error_description) - def request(self, method, url, withhold_token=False, auth=None, **kwargs): - if not withhold_token and auth is UNSET: + def request( + self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs + ): + if not withhold_token and auth is USE_CLIENT_DEFAULT: + if not self.token: + raise MissingTokenError() + + if not self.ensure_active_token(self.token): + raise InvalidTokenError() + + auth = self.token_auth + + return super().request(method, url, auth=auth, **kwargs) + + def stream( + self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs + ): + if not withhold_token and auth is USE_CLIENT_DEFAULT: if not self.token: raise MissingTokenError() - if self.token.is_expired(): - self.ensure_active_token() + if not self.ensure_active_token(self.token): + raise InvalidTokenError() auth = self.token_auth - return super(OAuth2Client, self).request( - method, url, auth=auth, **kwargs) - - def ensure_active_token(self): - refresh_token = self.token.get('refresh_token') - url = self.metadata.get('token_endpoint') - if refresh_token and url: - self.refresh_token(url, refresh_token=refresh_token) - elif self.metadata.get('grant_type') == 'client_credentials': - access_token = self.token['access_token'] - token = self.fetch_token(url, grant_type='client_credentials') - if self.update_token: - self.update_token(token, access_token=access_token) - else: - raise InvalidTokenError() + return super().stream(method, url, auth=auth, **kwargs) diff --git a/authlib/integrations/httpx_client/utils.py b/authlib/integrations/httpx_client/utils.py index 907aa6c8d..33c3a2fe6 100644 --- a/authlib/integrations/httpx_client/utils.py +++ b/authlib/integrations/httpx_client/utils.py @@ -1,8 +1,23 @@ +from httpx import Request + HTTPX_CLIENT_KWARGS = [ - 'headers', 'cookies', 'verify', 'cert', 'http_versions', - 'proxies', 'timeout', 'pool_limits', 'max_redirects', - 'base_url', 'dispatch', 'app', 'backend', 'trust_env', - 'json', + "headers", + "cookies", + "verify", + "cert", + "http1", + "http2", + "proxy", + "mounts", + "timeout", + "follow_redirects", + "limits", + "max_redirects", + "event_hooks", + "base_url", + "transport", + "trust_env", + "default_encoding", ] @@ -12,3 +27,15 @@ def extract_client_kwargs(kwargs): if k in kwargs: client_kwargs[k] = kwargs.pop(k) return client_kwargs + + +def build_request(url, headers, body, initial_request: Request) -> Request: + """Make sure that all the data from initial request is passed to the updated object.""" + updated_request = Request( + method=initial_request.method, url=url, headers=headers, content=body + ) + + if hasattr(initial_request, "extensions"): + updated_request.extensions = initial_request.extensions + + return updated_request diff --git a/authlib/integrations/requests_client/__init__.py b/authlib/integrations/requests_client/__init__.py index fcbdec320..c9c01df38 100644 --- a/authlib/integrations/requests_client/__init__.py +++ b/authlib/integrations/requests_client/__init__.py @@ -1,22 +1,28 @@ -from .oauth1_session import OAuth1Session, OAuth1Auth -from .oauth2_session import OAuth2Session, OAuth2Auth -from .assertion_session import AssertionSession -from ..base_client import OAuthError -from authlib.oauth1 import ( - SIGNATURE_HMAC_SHA1, - SIGNATURE_RSA_SHA1, - SIGNATURE_PLAINTEXT, - SIGNATURE_TYPE_HEADER, - SIGNATURE_TYPE_QUERY, - SIGNATURE_TYPE_BODY, -) +from authlib.oauth1 import SIGNATURE_HMAC_SHA1 +from authlib.oauth1 import SIGNATURE_PLAINTEXT +from authlib.oauth1 import SIGNATURE_RSA_SHA1 +from authlib.oauth1 import SIGNATURE_TYPE_BODY +from authlib.oauth1 import SIGNATURE_TYPE_HEADER +from authlib.oauth1 import SIGNATURE_TYPE_QUERY +from ..base_client import OAuthError +from .assertion_session import AssertionSession +from .oauth1_session import OAuth1Auth +from .oauth1_session import OAuth1Session +from .oauth2_session import OAuth2Auth +from .oauth2_session import OAuth2Session __all__ = [ - 'OAuthError', - 'OAuth1Session', 'OAuth1Auth', - 'SIGNATURE_HMAC_SHA1', 'SIGNATURE_RSA_SHA1', 'SIGNATURE_PLAINTEXT', - 'SIGNATURE_TYPE_HEADER', 'SIGNATURE_TYPE_QUERY', 'SIGNATURE_TYPE_BODY', - 'OAuth2Session', 'OAuth2Auth', - 'AssertionSession', + "OAuthError", + "OAuth1Session", + "OAuth1Auth", + "SIGNATURE_HMAC_SHA1", + "SIGNATURE_RSA_SHA1", + "SIGNATURE_PLAINTEXT", + "SIGNATURE_TYPE_HEADER", + "SIGNATURE_TYPE_QUERY", + "SIGNATURE_TYPE_BODY", + "OAuth2Session", + "OAuth2Auth", + "AssertionSession", ] diff --git a/authlib/integrations/requests_client/assertion_session.py b/authlib/integrations/requests_client/assertion_session.py index 1b95ea2f4..ee046077b 100644 --- a/authlib/integrations/requests_client/assertion_session.py +++ b/authlib/integrations/requests_client/assertion_session.py @@ -1,13 +1,17 @@ from requests import Session -from authlib.deprecate import deprecate + from authlib.oauth2.rfc7521 import AssertionClient from authlib.oauth2.rfc7523 import JWTBearerGrant + from .oauth2_session import OAuth2Auth +from .utils import update_session_configure class AssertionAuth(OAuth2Auth): def ensure_active_token(self): - if not self.token or self.token.is_expired() and self.client: + if self.client and ( + not self.token or self.token.is_expired(self.client.leeway) + ): return self.client.refresh_token() @@ -17,6 +21,7 @@ class AssertionSession(AssertionClient, Session): .. _RFC7521: https://tools.ietf.org/html/rfc7521 """ + token_auth_class = AssertionAuth JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE ASSERTION_METHODS = { @@ -24,25 +29,42 @@ class AssertionSession(AssertionClient, Session): } DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE - def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None, - claims=None, token_placement='header', scope=None, **kwargs): + def __init__( + self, + token_endpoint, + issuer, + subject, + audience=None, + grant_type=None, + claims=None, + token_placement="header", + scope=None, + default_timeout=None, + leeway=60, + **kwargs, + ): Session.__init__(self) - - token_url = kwargs.pop('token_url', None) - if token_url: - deprecate('Use "token_endpoint" instead of "token_url"', '1.0') - token_endpoint = token_url - + self.default_timeout = default_timeout + update_session_configure(self, kwargs) AssertionClient.__init__( - self, session=self, - token_endpoint=token_endpoint, issuer=issuer, subject=subject, - audience=audience, grant_type=grant_type, claims=claims, - token_placement=token_placement, scope=scope, **kwargs + self, + session=self, + token_endpoint=token_endpoint, + issuer=issuer, + subject=subject, + audience=audience, + grant_type=grant_type, + claims=claims, + token_placement=token_placement, + scope=scope, + leeway=leeway, + **kwargs, ) def request(self, method, url, withhold_token=False, auth=None, **kwargs): """Send request with auto refresh token feature.""" + if self.default_timeout: + kwargs.setdefault("timeout", self.default_timeout) if not withhold_token and auth is None: auth = self.token_auth - return super(AssertionSession, self).request( - method, url, auth=auth, **kwargs) + return super().request(method, url, auth=auth, **kwargs) diff --git a/authlib/integrations/requests_client/oauth1_session.py b/authlib/integrations/requests_client/oauth1_session.py index 26a12ac58..d9f5d3458 100644 --- a/authlib/integrations/requests_client/oauth1_session.py +++ b/authlib/integrations/requests_client/oauth1_session.py @@ -1,22 +1,21 @@ -# -*- coding: utf-8 -*- from requests import Session from requests.auth import AuthBase -from authlib.oauth1 import ( - SIGNATURE_HMAC_SHA1, - SIGNATURE_TYPE_HEADER, -) + from authlib.common.encoding import to_native +from authlib.oauth1 import SIGNATURE_HMAC_SHA1 +from authlib.oauth1 import SIGNATURE_TYPE_HEADER from authlib.oauth1 import ClientAuth from authlib.oauth1.client import OAuth1Client + from ..base_client import OAuthError +from .utils import update_session_configure class OAuth1Auth(AuthBase, ClientAuth): - """Signs the request using OAuth 1 (RFC5849)""" + """Signs the request using OAuth 1 (RFC5849).""" def __call__(self, req): - url, headers, body = self.prepare( - req.method, req.url, req.headers, req.body) + url, headers, body = self.prepare(req.method, req.url, req.headers, req.body) req.url = to_native(url) req.prepare_headers(headers) @@ -28,29 +27,46 @@ def __call__(self, req): class OAuth1Session(OAuth1Client, Session): auth_class = OAuth1Auth - def __init__(self, client_id, client_secret=None, - token=None, token_secret=None, - redirect_uri=None, rsa_key=None, verifier=None, - signature_method=SIGNATURE_HMAC_SHA1, - signature_type=SIGNATURE_TYPE_HEADER, - force_include_body=False, **kwargs): + def __init__( + self, + client_id, + client_secret=None, + token=None, + token_secret=None, + redirect_uri=None, + rsa_key=None, + verifier=None, + signature_method=SIGNATURE_HMAC_SHA1, + signature_type=SIGNATURE_TYPE_HEADER, + force_include_body=False, + **kwargs, + ): Session.__init__(self) + update_session_configure(self, kwargs) OAuth1Client.__init__( - self, session=self, - client_id=client_id, client_secret=client_secret, - token=token, token_secret=token_secret, - redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier, - signature_method=signature_method, signature_type=signature_type, - force_include_body=force_include_body, **kwargs) + self, + session=self, + client_id=client_id, + client_secret=client_secret, + token=token, + token_secret=token_secret, + redirect_uri=redirect_uri, + rsa_key=rsa_key, + verifier=verifier, + signature_method=signature_method, + signature_type=signature_type, + force_include_body=force_include_body, + **kwargs, + ) def rebuild_auth(self, prepared_request, response): """When being redirected we should always strip Authorization header, since nonce may not be reused as per OAuth spec. """ - if 'Authorization' in prepared_request.headers: + if "Authorization" in prepared_request.headers: # If we get redirected to a new host, we should strip out # any authentication headers. - prepared_request.headers.pop('Authorization', True) + prepared_request.headers.pop("Authorization", True) prepared_request.prepare_auth(self.auth) @staticmethod diff --git a/authlib/integrations/requests_client/oauth2_session.py b/authlib/integrations/requests_client/oauth2_session.py index 835487d20..2bacb18da 100644 --- a/authlib/integrations/requests_client/oauth2_session.py +++ b/authlib/integrations/requests_client/oauth2_session.py @@ -1,49 +1,41 @@ from requests import Session from requests.auth import AuthBase + +from authlib.oauth2.auth import ClientAuth +from authlib.oauth2.auth import TokenAuth from authlib.oauth2.client import OAuth2Client -from authlib.oauth2.auth import ClientAuth, TokenAuth -from ..base_client import ( - OAuthError, - InvalidTokenError, - MissingTokenError, - UnsupportedTokenTypeError, -) -__all__ = ['OAuth2Session', 'OAuth2Auth'] +from ..base_client import InvalidTokenError +from ..base_client import MissingTokenError +from ..base_client import OAuthError +from ..base_client import UnsupportedTokenTypeError +from .utils import update_session_configure + +__all__ = ["OAuth2Session", "OAuth2Auth"] class OAuth2Auth(AuthBase, TokenAuth): """Sign requests for OAuth 2.0, currently only bearer token is supported.""" def ensure_active_token(self): - if self.client and self.token.is_expired(): - refresh_token = self.token.get('refresh_token') - client = self.client - url = client.metadata.get('token_endpoint') - if refresh_token and url: - client.refresh_token(url, refresh_token=refresh_token) - elif client.metadata.get('grant_type') == 'client_credentials': - access_token = self.token['access_token'] - token = client.fetch_token(url, grant_type='client_credentials') - if client.update_token: - client.update_token(token, access_token=access_token) - else: - raise InvalidTokenError() + if self.client and not self.client.ensure_active_token(self.token): + raise InvalidTokenError() def __call__(self, req): self.ensure_active_token() try: req.url, req.headers, req.body = self.prepare( - req.url, req.headers, req.body) + req.url, req.headers, req.body + ) except KeyError as error: - description = 'Unsupported token_type: {}'.format(str(error)) - raise UnsupportedTokenTypeError(description=description) + description = f"Unsupported token_type: {str(error)}" + raise UnsupportedTokenTypeError(description=description) from error return req class OAuth2ClientAuth(AuthBase, ClientAuth): - """Attaches OAuth Client Authentication to the given Request object. - """ + """Attaches OAuth Client Authentication to the given Request object.""" + def __call__(self, req): req.url, req.headers, req.body = self.prepare( req.method, req.url, req.headers, req.body @@ -66,6 +58,7 @@ class OAuth2Session(OAuth2Client, Session): :param revocation_endpoint_auth_method: client authentication method for revocation endpoint. :param scope: Scope that you needed to access user resources. + :param state: Shared secret to prevent CSRF attack. :param redirect_uri: Redirect URI you registered as callback. :param token: A dict of token attributes such as ``access_token``, ``token_type`` and ``expires_at``. @@ -73,30 +66,63 @@ class OAuth2Session(OAuth2Client, Session): values: "header", "body", "uri". :param update_token: A function for you to update token. It accept a :class:`OAuth2Token` as parameter. + :param leeway: Time window in seconds before the actual expiration of the + authentication token, that the token is considered expired and will + be refreshed. + :param default_timeout: If settled, every requests will have a default timeout. """ + client_auth_class = OAuth2ClientAuth token_auth_class = OAuth2Auth + oauth_error_class = OAuthError SESSION_REQUEST_PARAMS = ( - 'allow_redirects', 'timeout', 'cookies', 'files', - 'proxies', 'hooks', 'stream', 'verify', 'cert', 'json' + "allow_redirects", + "timeout", + "cookies", + "files", + "proxies", + "hooks", + "stream", + "verify", + "cert", + "json", ) - def __init__(self, client_id=None, client_secret=None, - token_endpoint_auth_method=None, - revocation_endpoint_auth_method=None, - scope=None, redirect_uri=None, - token=None, token_placement='header', - update_token=None, **kwargs): - + def __init__( + self, + client_id=None, + client_secret=None, + token_endpoint_auth_method=None, + revocation_endpoint_auth_method=None, + scope=None, + state=None, + redirect_uri=None, + token=None, + token_placement="header", + update_token=None, + leeway=60, + default_timeout=None, + **kwargs, + ): Session.__init__(self) + self.default_timeout = default_timeout + update_session_configure(self, kwargs) + OAuth2Client.__init__( - self, session=self, - client_id=client_id, client_secret=client_secret, + self, + session=self, + client_id=client_id, + client_secret=client_secret, token_endpoint_auth_method=token_endpoint_auth_method, revocation_endpoint_auth_method=revocation_endpoint_auth_method, - scope=scope, redirect_uri=redirect_uri, - token=token, token_placement=token_placement, - update_token=update_token, **kwargs + scope=scope, + state=state, + redirect_uri=redirect_uri, + token=token, + token_placement=token_placement, + update_token=update_token, + leeway=leeway, + **kwargs, ) def fetch_access_token(self, url=None, **kwargs): @@ -105,13 +131,10 @@ def fetch_access_token(self, url=None, **kwargs): def request(self, method, url, withhold_token=False, auth=None, **kwargs): """Send request with auto refresh token feature (if available).""" + if self.default_timeout: + kwargs.setdefault("timeout", self.default_timeout) if not withhold_token and auth is None: if not self.token: raise MissingTokenError() auth = self.token_auth - return super(OAuth2Session, self).request( - method, url, auth=auth, **kwargs) - - @staticmethod - def handle_error(error_type, error_description): - raise OAuthError(error_type, error_description) + return super().request(method, url, auth=auth, **kwargs) diff --git a/authlib/integrations/requests_client/utils.py b/authlib/integrations/requests_client/utils.py new file mode 100644 index 000000000..dc9670501 --- /dev/null +++ b/authlib/integrations/requests_client/utils.py @@ -0,0 +1,15 @@ +REQUESTS_SESSION_KWARGS = [ + "proxies", + "hooks", + "stream", + "verify", + "cert", + "max_redirects", + "trust_env", +] + + +def update_session_configure(session, kwargs): + for k in REQUESTS_SESSION_KWARGS: + if k in kwargs: + setattr(session, k, kwargs.pop(k)) diff --git a/authlib/integrations/sqla_oauth1/__init__.py b/authlib/integrations/sqla_oauth1/__init__.py deleted file mode 100644 index 75f7730bd..000000000 --- a/authlib/integrations/sqla_oauth1/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# flake8: noqa - -from .mixins import ( - OAuth1ClientMixin, - OAuth1TemporaryCredentialMixin, - OAuth1TimestampNonceMixin, - OAuth1TokenCredentialMixin, -) -from .functions import ( - create_query_client_func, - create_query_token_func, - register_temporary_credential_hooks, - create_exists_nonce_func, - register_nonce_hooks, - register_token_credential_hooks, - register_authorization_hooks, -) diff --git a/authlib/integrations/sqla_oauth1/functions.py b/authlib/integrations/sqla_oauth1/functions.py deleted file mode 100644 index 31bb48e87..000000000 --- a/authlib/integrations/sqla_oauth1/functions.py +++ /dev/null @@ -1,154 +0,0 @@ -def create_query_client_func(session, model_class): - """Create an ``query_client`` function that can be used in authorization - server and resource protector. - - :param session: SQLAlchemy session - :param model_class: Client class - """ - def query_client(client_id): - q = session.query(model_class) - return q.filter_by(client_id=client_id).first() - return query_client - - -def create_query_token_func(session, model_class): - """Create an ``query_token`` function that can be used in - resource protector. - - :param session: SQLAlchemy session - :param model_class: TokenCredential class - """ - def query_token(client_id, oauth_token): - q = session.query(model_class) - return q.filter_by( - client_id=client_id, oauth_token=oauth_token).first() - return query_token - - -def register_temporary_credential_hooks( - authorization_server, session, model_class): - """Register temporary credential related hooks to authorization server. - - :param authorization_server: AuthorizationServer instance - :param session: SQLAlchemy session - :param model_class: TemporaryCredential class - """ - - def create_temporary_credential(token, client_id, redirect_uri): - item = model_class( - client_id=client_id, - oauth_token=token['oauth_token'], - oauth_token_secret=token['oauth_token_secret'], - oauth_callback=redirect_uri, - ) - session.add(item) - session.commit() - return item - - def get_temporary_credential(oauth_token): - q = session.query(model_class).filter_by(oauth_token=oauth_token) - return q.first() - - def delete_temporary_credential(oauth_token): - q = session.query(model_class).filter_by(oauth_token=oauth_token) - q.delete(synchronize_session=False) - session.commit() - - def create_authorization_verifier(credential, grant_user, verifier): - credential.set_user_id(grant_user.get_user_id()) - credential.oauth_verifier = verifier - session.add(credential) - session.commit() - return credential - - authorization_server.register_hook( - 'create_temporary_credential', create_temporary_credential) - authorization_server.register_hook( - 'get_temporary_credential', get_temporary_credential) - authorization_server.register_hook( - 'delete_temporary_credential', delete_temporary_credential) - authorization_server.register_hook( - 'create_authorization_verifier', create_authorization_verifier) - - -def create_exists_nonce_func(session, model_class): - """Create an ``exists_nonce`` function that can be used in hooks and - resource protector. - - :param session: SQLAlchemy session - :param model_class: TimestampNonce class - """ - def exists_nonce(nonce, timestamp, client_id, oauth_token): - q = session.query(model_class.nonce).filter_by( - nonce=nonce, - timestamp=timestamp, - client_id=client_id, - ) - if oauth_token: - q = q.filter_by(oauth_token=oauth_token) - rv = q.first() - if rv: - return True - - item = model_class( - nonce=nonce, - timestamp=timestamp, - client_id=client_id, - oauth_token=oauth_token, - ) - session.add(item) - session.commit() - return False - return exists_nonce - - -def register_nonce_hooks(authorization_server, session, model_class): - """Register nonce related hooks to authorization server. - - :param authorization_server: AuthorizationServer instance - :param session: SQLAlchemy session - :param model_class: TimestampNonce class - """ - exists_nonce = create_exists_nonce_func(session, model_class) - authorization_server.register_hook('exists_nonce', exists_nonce) - - -def register_token_credential_hooks( - authorization_server, session, model_class): - """Register token credential related hooks to authorization server. - - :param authorization_server: AuthorizationServer instance - :param session: SQLAlchemy session - :param model_class: TokenCredential class - """ - def create_token_credential(token, temporary_credential): - item = model_class( - oauth_token=token['oauth_token'], - oauth_token_secret=token['oauth_token_secret'], - client_id=temporary_credential.get_client_id() - ) - item.set_user_id(temporary_credential.get_user_id()) - session.add(item) - session.commit() - return item - - authorization_server.register_hook( - 'create_token_credential', create_token_credential) - - -def register_authorization_hooks( - authorization_server, session, - token_credential_model, - temporary_credential_model=None, - timestamp_nonce_model=None): - - register_token_credential_hooks( - authorization_server, session, token_credential_model) - - if temporary_credential_model is not None: - register_temporary_credential_hooks( - authorization_server, session, temporary_credential_model) - - if timestamp_nonce_model is not None: - register_nonce_hooks( - authorization_server, session, timestamp_nonce_model) diff --git a/authlib/integrations/sqla_oauth1/mixins.py b/authlib/integrations/sqla_oauth1/mixins.py deleted file mode 100644 index a72dd012e..000000000 --- a/authlib/integrations/sqla_oauth1/mixins.py +++ /dev/null @@ -1,97 +0,0 @@ -from sqlalchemy import Column, UniqueConstraint -from sqlalchemy import String, Integer, Text -from authlib.oauth1 import ( - ClientMixin, - TemporaryCredentialMixin, - TokenCredentialMixin, -) - - -class OAuth1ClientMixin(ClientMixin): - client_id = Column(String(48), index=True) - client_secret = Column(String(120), nullable=False) - default_redirect_uri = Column(Text, nullable=False, default='') - - def get_default_redirect_uri(self): - return self.default_redirect_uri - - def get_client_secret(self): - return self.client_secret - - def get_rsa_public_key(self): - return None - - -class OAuth1TemporaryCredentialMixin(TemporaryCredentialMixin): - client_id = Column(String(48), index=True) - oauth_token = Column(String(84), unique=True, index=True) - oauth_token_secret = Column(String(84)) - oauth_verifier = Column(String(84)) - oauth_callback = Column(Text, default='') - - def get_user_id(self): - """A method to get the grant user information of this temporary - credential. For instance, grant user is stored in database on - ``user_id`` column:: - - def get_user_id(self): - return self.user_id - - :return: User ID - """ - if hasattr(self, 'user_id'): - return self.user_id - else: - raise NotImplementedError() - - def set_user_id(self, user_id): - if hasattr(self, 'user_id'): - setattr(self, 'user_id', user_id) - else: - raise NotImplementedError() - - def get_client_id(self): - return self.client_id - - def get_redirect_uri(self): - return self.oauth_callback - - def check_verifier(self, verifier): - return self.oauth_verifier == verifier - - def get_oauth_token(self): - return self.oauth_token - - def get_oauth_token_secret(self): - return self.oauth_token_secret - - -class OAuth1TimestampNonceMixin(object): - __table_args__ = ( - UniqueConstraint( - 'client_id', 'timestamp', 'nonce', 'oauth_token', - name='unique_nonce' - ), - ) - client_id = Column(String(48), nullable=False) - timestamp = Column(Integer, nullable=False) - nonce = Column(String(48), nullable=False) - oauth_token = Column(String(84)) - - -class OAuth1TokenCredentialMixin(TokenCredentialMixin): - client_id = Column(String(48), index=True) - oauth_token = Column(String(84), unique=True, index=True) - oauth_token_secret = Column(String(84)) - - def set_user_id(self, user_id): - if hasattr(self, 'user_id'): - setattr(self, 'user_id', user_id) - else: - raise NotImplementedError() - - def get_oauth_token(self): - return self.oauth_token - - def get_oauth_token_secret(self): - return self.oauth_token_secret diff --git a/authlib/integrations/sqla_oauth2/__init__.py b/authlib/integrations/sqla_oauth2/__init__.py index 1964aa1a5..e2f806aab 100644 --- a/authlib/integrations/sqla_oauth2/__init__.py +++ b/authlib/integrations/sqla_oauth2/__init__.py @@ -1,17 +1,19 @@ from .client_mixin import OAuth2ClientMixin -from .tokens_mixins import OAuth2AuthorizationCodeMixin, OAuth2TokenMixin -from .functions import ( - create_query_client_func, - create_save_token_func, - create_query_token_func, - create_revocation_endpoint, - create_bearer_token_validator, -) - +from .functions import create_bearer_token_validator +from .functions import create_query_client_func +from .functions import create_query_token_func +from .functions import create_revocation_endpoint +from .functions import create_save_token_func +from .tokens_mixins import OAuth2AuthorizationCodeMixin +from .tokens_mixins import OAuth2TokenMixin __all__ = [ - 'OAuth2ClientMixin', 'OAuth2AuthorizationCodeMixin', 'OAuth2TokenMixin', - 'create_query_client_func', 'create_save_token_func', - 'create_query_token_func', 'create_revocation_endpoint', - 'create_bearer_token_validator', + "OAuth2ClientMixin", + "OAuth2AuthorizationCodeMixin", + "OAuth2TokenMixin", + "create_query_client_func", + "create_save_token_func", + "create_query_token_func", + "create_revocation_endpoint", + "create_bearer_token_validator", ] diff --git a/authlib/integrations/sqla_oauth2/client_mixin.py b/authlib/integrations/sqla_oauth2/client_mixin.py index b88b4ad8c..c8835086d 100644 --- a/authlib/integrations/sqla_oauth2/client_mixin.py +++ b/authlib/integrations/sqla_oauth2/client_mixin.py @@ -1,7 +1,15 @@ -from sqlalchemy import Column, String, Text, Integer -from authlib.common.encoding import json_loads, json_dumps +import secrets + +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy import Text + +from authlib.common.encoding import json_dumps +from authlib.common.encoding import json_loads from authlib.oauth2.rfc6749 import ClientMixin -from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope +from authlib.oauth2.rfc6749 import list_to_scope +from authlib.oauth2.rfc6749 import scope_to_list class OAuth2ClientMixin(ClientMixin): @@ -9,7 +17,7 @@ class OAuth2ClientMixin(ClientMixin): client_secret = Column(String(120)) client_id_issued_at = Column(Integer, nullable=False, default=0) client_secret_expires_at = Column(Integer, nullable=False, default=0) - _client_metadata = Column('client_metadata', Text) + _client_metadata = Column("client_metadata", Text) @property def client_info(self): @@ -27,79 +35,84 @@ def client_info(self): @property def client_metadata(self): - if 'client_metadata' in self.__dict__: - return self.__dict__['client_metadata'] + if "client_metadata" in self.__dict__: + return self.__dict__["client_metadata"] if self._client_metadata: data = json_loads(self._client_metadata) - self.__dict__['client_metadata'] = data + self.__dict__["client_metadata"] = data return data return {} def set_client_metadata(self, value): self._client_metadata = json_dumps(value) + if "client_metadata" in self.__dict__: + del self.__dict__["client_metadata"] @property def redirect_uris(self): - return self.client_metadata.get('redirect_uris', []) + return self.client_metadata.get("redirect_uris", []) @property def token_endpoint_auth_method(self): return self.client_metadata.get( - 'token_endpoint_auth_method', - 'client_secret_basic' + "token_endpoint_auth_method", "client_secret_basic" ) @property def grant_types(self): - return self.client_metadata.get('grant_types', []) + return self.client_metadata.get("grant_types", []) @property def response_types(self): - return self.client_metadata.get('response_types', []) + return self.client_metadata.get("response_types", []) @property def client_name(self): - return self.client_metadata.get('client_name') + return self.client_metadata.get("client_name") @property def client_uri(self): - return self.client_metadata.get('client_uri') + return self.client_metadata.get("client_uri") @property def logo_uri(self): - return self.client_metadata.get('logo_uri') + return self.client_metadata.get("logo_uri") @property def scope(self): - return self.client_metadata.get('scope', '') + return self.client_metadata.get("scope", "") @property def contacts(self): - return self.client_metadata.get('contacts', []) + return self.client_metadata.get("contacts", []) @property def tos_uri(self): - return self.client_metadata.get('tos_uri') + return self.client_metadata.get("tos_uri") @property def policy_uri(self): - return self.client_metadata.get('policy_uri') + return self.client_metadata.get("policy_uri") @property def jwks_uri(self): - return self.client_metadata.get('jwks_uri') + return self.client_metadata.get("jwks_uri") @property def jwks(self): - return self.client_metadata.get('jwks', []) + return self.client_metadata.get("jwks", []) @property def software_id(self): - return self.client_metadata.get('software_id') + return self.client_metadata.get("software_id") @property def software_version(self): - return self.client_metadata.get('software_version') + return self.client_metadata.get("software_version") + + @property + def id_token_signed_response_alg(self): + return self.client_metadata.get("id_token_signed_response_alg") def get_client_id(self): return self.client_id @@ -110,7 +123,7 @@ def get_default_redirect_uri(self): def get_allowed_scope(self, scope): if not scope: - return '' + return "" allowed = set(self.scope.split()) scopes = scope_to_list(scope) return list_to_scope([s for s in scopes if s in allowed]) @@ -118,14 +131,14 @@ def get_allowed_scope(self, scope): def check_redirect_uri(self, redirect_uri): return redirect_uri in self.redirect_uris - def has_client_secret(self): - return bool(self.client_secret) - def check_client_secret(self, client_secret): - return self.client_secret == client_secret + return secrets.compare_digest(self.client_secret, client_secret) - def check_token_endpoint_auth_method(self, method): - return self.token_endpoint_auth_method == method + def check_endpoint_auth_method(self, method, endpoint): + if endpoint == "token": + return self.token_endpoint_auth_method == method + # TODO + return True def check_response_type(self, response_type): return response_type in self.response_types diff --git a/authlib/integrations/sqla_oauth2/functions.py b/authlib/integrations/sqla_oauth2/functions.py index f79337bfa..d10ab24e8 100644 --- a/authlib/integrations/sqla_oauth2/functions.py +++ b/authlib/integrations/sqla_oauth2/functions.py @@ -1,3 +1,6 @@ +import time + + def create_query_client_func(session, client_model): """Create an ``query_client`` function that can be used in authorization server. @@ -5,9 +8,11 @@ def create_query_client_func(session, client_model): :param session: SQLAlchemy session :param client_model: Client model class """ + def query_client(client_id): q = session.query(client_model) return q.filter_by(client_id=client_id).first() + return query_client @@ -18,19 +23,17 @@ def create_save_token_func(session, token_model): :param session: SQLAlchemy session :param token_model: Token model class """ + def save_token(token, request): if request.user: user_id = request.user.get_user_id() else: user_id = None client = request.client - item = token_model( - client_id=client.client_id, - user_id=user_id, - **token - ) + item = token_model(client_id=client.client_id, user_id=user_id, **token) session.add(item) session.commit() + return save_token @@ -41,18 +44,19 @@ def create_query_token_func(session, token_model): :param session: SQLAlchemy session :param token_model: Token model class """ - def query_token(token, token_type_hint, client): + + def query_token(token, token_type_hint): q = session.query(token_model) - q = q.filter_by(client_id=client.client_id, revoked=False) - if token_type_hint == 'access_token': + if token_type_hint == "access_token": return q.filter_by(access_token=token).first() - elif token_type_hint == 'refresh_token': + elif token_type_hint == "refresh_token": return q.filter_by(refresh_token=token).first() # without token_type_hint item = q.filter_by(access_token=token).first() if item: return item return q.filter_by(refresh_token=token).first() + return query_token @@ -64,14 +68,19 @@ def create_revocation_endpoint(session, token_model): :param token_model: Token model class """ from authlib.oauth2.rfc7009 import RevocationEndpoint + query_token = create_query_token_func(session, token_model) class _RevocationEndpoint(RevocationEndpoint): - def query_token(self, token, token_type_hint, client): - return query_token(token, token_type_hint, client) - - def revoke_token(self, token): - token.revoked = True + def query_token(self, token, token_type_hint): + return query_token(token, token_type_hint) + + def revoke_token(self, token, request): + now = int(time.time()) + hint = request.form.get("token_type_hint") + token.access_token_revoked_at = now + if hint != "access_token": + token.refresh_token_revoked_at = now session.add(token) session.commit() @@ -92,10 +101,4 @@ def authenticate_token(self, token_string): q = session.query(token_model) return q.filter_by(access_token=token_string).first() - def request_invalid(self, request): - return False - - def token_revoked(self, token): - return token.revoked - return _BearerTokenValidator diff --git a/authlib/integrations/sqla_oauth2/tokens_mixins.py b/authlib/integrations/sqla_oauth2/tokens_mixins.py index fcd28e26b..91808e358 100644 --- a/authlib/integrations/sqla_oauth2/tokens_mixins.py +++ b/authlib/integrations/sqla_oauth2/tokens_mixins.py @@ -1,22 +1,24 @@ import time -from sqlalchemy import Column, String, Boolean, Text, Integer -from authlib.oauth2.rfc6749 import ( - TokenMixin, - AuthorizationCodeMixin, -) + +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy import Text + +from authlib.oauth2.rfc6749 import AuthorizationCodeMixin +from authlib.oauth2.rfc6749 import TokenMixin class OAuth2AuthorizationCodeMixin(AuthorizationCodeMixin): code = Column(String(120), unique=True, nullable=False) client_id = Column(String(48)) - redirect_uri = Column(Text, default='') - response_type = Column(Text, default='') - scope = Column(Text, default='') + redirect_uri = Column(Text, default="") + response_type = Column(Text, default="") + scope = Column(Text, default="") nonce = Column(Text) - auth_time = Column( - Integer, nullable=False, - default=lambda: int(time.time()) - ) + auth_time = Column(Integer, nullable=False, default=lambda: int(time.time())) + acr = Column(Text, nullable=True) + amr = Column(Text, nullable=True) code_challenge = Column(Text) code_challenge_method = Column(String(48)) @@ -33,6 +35,12 @@ def get_scope(self): def get_auth_time(self): return self.auth_time + def get_acr(self): + return self.acr + + def get_amr(self): + return self.amr.split() if self.amr else [] + def get_nonce(self): return self.nonce @@ -42,15 +50,14 @@ class OAuth2TokenMixin(TokenMixin): token_type = Column(String(40)) access_token = Column(String(255), unique=True, nullable=False) refresh_token = Column(String(255), index=True) - scope = Column(Text, default='') - revoked = Column(Boolean, default=False) - issued_at = Column( - Integer, nullable=False, default=lambda: int(time.time()) - ) + scope = Column(Text, default="") + issued_at = Column(Integer, nullable=False, default=lambda: int(time.time())) + access_token_revoked_at = Column(Integer, nullable=False, default=0) + refresh_token_revoked_at = Column(Integer, nullable=False, default=0) expires_in = Column(Integer, nullable=False, default=0) - def get_client_id(self): - return self.client_id + def check_client(self, client): + return self.client_id == client.get_client_id() def get_scope(self): return self.scope @@ -58,5 +65,12 @@ def get_scope(self): def get_expires_in(self): return self.expires_in - def get_expires_at(self): - return self.issued_at + self.expires_in + def is_revoked(self): + return self.access_token_revoked_at or self.refresh_token_revoked_at + + def is_expired(self): + if not self.expires_in: + return False + + expires_at = self.issued_at + self.expires_in + return expires_at < time.time() diff --git a/authlib/integrations/starlette_client/__init__.py b/authlib/integrations/starlette_client/__init__.py index c4dbe9fc6..e7d963782 100644 --- a/authlib/integrations/starlette_client/__init__.py +++ b/authlib/integrations/starlette_client/__init__.py @@ -1,20 +1,26 @@ -# flake8: noqa - -from ..base_client import BaseOAuth, OAuthError -from .integration import StartletteIntegration, StarletteRemoteApp +from ..base_client import BaseOAuth +from ..base_client import OAuthError +from .apps import StarletteOAuth1App +from .apps import StarletteOAuth2App +from .integration import StarletteIntegration class OAuth(BaseOAuth): - framework_client_cls = StarletteRemoteApp - framework_integration_cls = StartletteIntegration + oauth1_client_cls = StarletteOAuth1App + oauth2_client_cls = StarletteOAuth2App + framework_integration_cls = StarletteIntegration def __init__(self, config=None, cache=None, fetch_token=None, update_token=None): - super(OAuth, self).__init__(fetch_token, update_token) - self.cache = cache + super().__init__( + cache=cache, fetch_token=fetch_token, update_token=update_token + ) self.config = config __all__ = [ - 'OAuth', 'StartletteIntegration', 'StarletteRemoteApp', - 'OAuthError', + "OAuth", + "OAuthError", + "StarletteIntegration", + "StarletteOAuth1App", + "StarletteOAuth2App", ] diff --git a/authlib/integrations/starlette_client/apps.py b/authlib/integrations/starlette_client/apps.py new file mode 100644 index 000000000..e80def21a --- /dev/null +++ b/authlib/integrations/starlette_client/apps.py @@ -0,0 +1,145 @@ +from starlette.datastructures import URL +from starlette.responses import RedirectResponse + +from ..base_client import BaseApp +from ..base_client import OAuthError +from ..base_client.async_app import AsyncOAuth1Mixin +from ..base_client.async_app import AsyncOAuth2Mixin +from ..base_client.async_openid import AsyncOpenIDMixin +from ..httpx_client import AsyncOAuth1Client +from ..httpx_client import AsyncOAuth2Client + + +class StarletteAppMixin: + async def save_authorize_data(self, request, **kwargs): + state = kwargs.pop("state", None) + if state: + await self.framework.set_state_data(request.session, state, kwargs) + else: + raise RuntimeError("Missing state value") + + async def authorize_redirect(self, request, redirect_uri=None, **kwargs): + """Create a HTTP Redirect for Authorization Endpoint. + + :param request: HTTP request instance from Starlette view. + :param redirect_uri: Callback or redirect URI for authorization. + :param kwargs: Extra parameters to include. + :return: A HTTP redirect response. + """ + # Handle Starlette >= 0.26.0 where redirect_uri may now be a URL and not a string + if redirect_uri and isinstance(redirect_uri, URL): + redirect_uri = str(redirect_uri) + rv = await self.create_authorization_url(redirect_uri, **kwargs) + await self.save_authorize_data(request, redirect_uri=redirect_uri, **rv) + return RedirectResponse(rv["url"], status_code=302) + + +class StarletteOAuth1App(StarletteAppMixin, AsyncOAuth1Mixin, BaseApp): + client_cls = AsyncOAuth1Client + + async def authorize_access_token(self, request, **kwargs): + params = dict(request.query_params) + state = params.get("oauth_token") + if not state: + raise OAuthError(description='Missing "oauth_token" parameter') + + data = await self.framework.get_state_data(request.session, state) + if not data: + raise OAuthError(description='Missing "request_token" in temporary data') + + params["request_token"] = data["request_token"] + params.update(kwargs) + await self.framework.clear_state_data(request.session, state) + return await self.fetch_access_token(**params) + + +class StarletteOAuth2App( + StarletteAppMixin, AsyncOAuth2Mixin, AsyncOpenIDMixin, BaseApp +): + client_cls = AsyncOAuth2Client + + async def logout_redirect( + self, request, post_logout_redirect_uri=None, id_token_hint=None, **kwargs + ): + """Create a HTTP Redirect for End Session Endpoint (RP-Initiated Logout). + + :param request: HTTP request instance from Starlette view. + :param post_logout_redirect_uri: URI to redirect after logout. + :param id_token_hint: ID Token previously issued to the RP. + :param kwargs: Extra parameters (state, client_id, logout_hint, ui_locales). + :return: A HTTP redirect response. + """ + if post_logout_redirect_uri and isinstance(post_logout_redirect_uri, URL): + post_logout_redirect_uri = str(post_logout_redirect_uri) + result = await self.create_logout_url( + post_logout_redirect_uri=post_logout_redirect_uri, + id_token_hint=id_token_hint, + **kwargs, + ) + if result.get("state"): + await self.framework.set_state_data( + request.session, + result["state"], + { + "post_logout_redirect_uri": post_logout_redirect_uri, + }, + ) + return RedirectResponse(result["url"], status_code=302) + + async def validate_logout_response(self, request): + """Validate the state parameter from the logout callback. + + :param request: HTTP request instance from Starlette view. + :return: The state data dict. + :raises OAuthError: If state is missing or invalid. + """ + state = request.query_params.get("state") + if not state: + raise OAuthError(description='Missing "state" parameter') + + state_data = await self.framework.get_state_data(request.session, state) + if not state_data: + raise OAuthError(description='Invalid "state" parameter') + + await self.framework.clear_state_data(request.session, state) + return state_data + + async def authorize_access_token(self, request, **kwargs): + if request.scope.get("method", "GET") == "GET": + error = request.query_params.get("error") + if error: + description = request.query_params.get("error_description") + raise OAuthError(error=error, description=description) + + params = { + "code": request.query_params.get("code"), + "state": request.query_params.get("state"), + } + else: + async with request.form() as form: + params = { + "code": form.get("code"), + "state": form.get("state"), + } + + state_data = await self.framework.get_state_data( + request.session, params.get("state") + ) + await self.framework.clear_state_data(request.session, params.get("state")) + params = self._format_state_params(state_data, params) + + claims_options = kwargs.pop("claims_options", None) + claims_cls = kwargs.pop("claims_cls", None) + leeway = kwargs.pop("leeway", 120) + token = await self.fetch_access_token(**params, **kwargs) + + if "id_token" in token and "nonce" in state_data: + userinfo = await self.parse_id_token( + token, + nonce=state_data["nonce"], + claims_options=claims_options, + claims_cls=claims_cls, + leeway=leeway, + ) + token["userinfo"] = userinfo + return token diff --git a/authlib/integrations/starlette_client/integration.py b/authlib/integrations/starlette_client/integration.py index e8e7eb1de..224c39aaf 100644 --- a/authlib/integrations/starlette_client/integration.py +++ b/authlib/integrations/starlette_client/integration.py @@ -1,24 +1,72 @@ -from starlette.responses import RedirectResponse -from ..httpx_client import AsyncOAuth1Client, AsyncOAuth2Client +import json +import time +from collections.abc import Hashable +from typing import Any + from ..base_client import FrameworkIntegration -from ..base_client.async_app import AsyncRemoteApp -class StartletteIntegration(FrameworkIntegration): - oauth1_client_cls = AsyncOAuth1Client - oauth2_client_cls = AsyncOAuth2Client +class StarletteIntegration(FrameworkIntegration): + async def _get_cache_data(self, key: Hashable): + value = await self.cache.get(key) + if not value: + return None + try: + return json.loads(value) + except (TypeError, ValueError): + return None + + async def get_state_data( + self, session: dict[str, Any] | None, state: str + ) -> dict[str, Any]: + key = f"_state_{self.name}_{state}" + if self.cache: + # require a session-bound marker to prove the callback originates + # from the user-agent that started the flow (RFC 6749 §10.12) + if session is None or session.get(key) is None: + return None + value = await self._get_cache_data(key) + elif session is not None: + value = session.get(key) + else: + value = None + + if value: + return value.get("data") + return None + + async def set_state_data( + self, session: dict[str, Any] | None, state: str, data: Any + ): + key_prefix = f"_state_{self.name}_" + key = f"{key_prefix}{state}" + now = time.time() + if self.cache: + await self.cache.set(key, json.dumps({"data": data}), self.expires_in) + if session is not None: + # clear old state data to avoid session size growing + for old_key in list(session.keys()): + if old_key.startswith(key_prefix): + session.pop(old_key) + session[key] = {"exp": now + self.expires_in} + elif session is not None: + # clear old state data to avoid session size growing + for old_key in list(session.keys()): + if old_key.startswith(key_prefix): + session.pop(old_key) + session[key] = {"data": data, "exp": now + self.expires_in} + + async def clear_state_data(self, session: dict[str, Any] | None, state: str): + key = f"_state_{self.name}_{state}" + if self.cache: + await self.cache.delete(key) + if session is not None: + session.pop(key, None) + self._clear_session_state(session) def update_token(self, token, refresh_token=None, access_token=None): pass - def generate_access_token_params(self, request_token_url, request): - if request_token_url: - return dict(request.query_params) - return { - 'code': request.query_params.get('code'), - 'state': request.query_params.get('state'), - } - @staticmethod def load_config(oauth, name, params): if not oauth.config: @@ -26,41 +74,8 @@ def load_config(oauth, name, params): rv = {} for k in params: - conf_key = '{}_{}'.format(name, k).upper() + conf_key = f"{name}_{k}".upper() v = oauth.config.get(conf_key, default=None) if v is not None: rv[k] = v return rv - - -class StarletteRemoteApp(AsyncRemoteApp): - - async def authorize_redirect(self, request, redirect_uri=None, **kwargs): - """Create a HTTP Redirect for Authorization Endpoint. - - :param request: Starlette Request instance. - :param redirect_uri: Callback or redirect URI for authorization. - :param kwargs: Extra parameters to include. - :return: Starlette ``RedirectResponse`` instance. - """ - rv = await self.create_authorization_url(redirect_uri, **kwargs) - self.save_authorize_data(request, redirect_uri=redirect_uri, **rv) - return RedirectResponse(rv['url'], status_code=302) - - async def authorize_access_token(self, request, **kwargs): - """Fetch an access token. - - :param request: Starlette Request instance. - :return: A token dict. - """ - params = self.retrieve_access_token_params(request) - params.update(kwargs) - return await self.fetch_access_token(**params) - - async def parse_id_token(self, request, token, claims_options=None): - """Return an instance of UserInfo from token's ``id_token``.""" - if 'id_token' not in token: - return None - - nonce = self.framework.get_session_data(request, 'nonce') - return await self._parse_id_token(token, nonce, claims_options) diff --git a/authlib/jose/__init__.py b/authlib/jose/__init__.py index 86db6a709..c13a278fa 100644 --- a/authlib/jose/__init__.py +++ b/authlib/jose/__init__.py @@ -1,40 +1,48 @@ -""" - authlib.jose - ~~~~~~~~~~~~ +"""authlib.jose +~~~~~~~~~~~~ - JOSE implementation in Authlib. Tracking the status of JOSE specs at - https://tools.ietf.org/wg/jose/ +JOSE implementation in Authlib. Tracking the status of JOSE specs at +https://tools.ietf.org/wg/jose/ """ -from .rfc7515 import ( - JsonWebSignature, JWSAlgorithm, JWSHeader, JWSObject, -) -from .rfc7516 import ( - JsonWebEncryption, JWEAlgorithm, JWEEncAlgorithm, JWEZipAlgorithm, -) -from .rfc7517 import Key, KeySet -from .rfc7518 import ( - register_jws_rfc7518, - register_jwe_rfc7518, - ECDHAlgorithm, - OctKey, - RSAKey, - ECKey, -) -from .rfc7519 import JsonWebToken, BaseClaims, JWTClaims -from .rfc8037 import OKPKey, register_jws_rfc8037 -from .drafts import register_jwe_draft + +from authlib.deprecate import deprecate from .errors import JoseError -from .jwk import JsonWebKey +from .rfc7515 import JsonWebSignature +from .rfc7515 import JWSAlgorithm +from .rfc7515 import JWSHeader +from .rfc7515 import JWSObject +from .rfc7516 import JsonWebEncryption +from .rfc7516 import JWEAlgorithm +from .rfc7516 import JWEEncAlgorithm +from .rfc7516 import JWEZipAlgorithm +from .rfc7517 import JsonWebKey +from .rfc7517 import Key +from .rfc7517 import KeySet +from .rfc7518 import ECDHESAlgorithm +from .rfc7518 import ECKey +from .rfc7518 import OctKey +from .rfc7518 import RSAKey +from .rfc7518 import register_jwe_rfc7518 +from .rfc7518 import register_jws_rfc7518 +from .rfc7519 import BaseClaims +from .rfc7519 import JsonWebToken +from .rfc7519 import JWTClaims +from .rfc8037 import OKPKey +from .rfc8037 import register_jws_rfc8037 + +deprecate( + "authlib.jose module is deprecated, please use joserfc instead.", version="2.0.0" +) # register algorithms -register_jws_rfc7518() -register_jwe_rfc7518() -register_jws_rfc8037() -register_jwe_draft() +register_jws_rfc7518(JsonWebSignature) +register_jws_rfc8037(JsonWebSignature) + +register_jwe_rfc7518(JsonWebEncryption) # attach algorithms -ECDHAlgorithm.ALLOWED_KEY_CLS = (ECKey, OKPKey) +ECDHESAlgorithm.ALLOWED_KEY_CLS = (ECKey, OKPKey) # register supported keys JsonWebKey.JWK_KEY_CLS = { @@ -44,32 +52,45 @@ OKPKey.kty: OKPKey, } -# compatible constants -JWS_ALGORITHMS = list(JsonWebSignature.ALGORITHMS_REGISTRY.keys()) -JWE_ALG_ALGORITHMS = list(JsonWebEncryption.ALG_REGISTRY.keys()) -JWE_ENC_ALGORITHMS = list(JsonWebEncryption.ENC_REGISTRY.keys()) -JWE_ZIP_ALGORITHMS = list(JsonWebEncryption.ZIP_REGISTRY.keys()) -JWE_ALGORITHMS = JWE_ALG_ALGORITHMS + JWE_ENC_ALGORITHMS + JWE_ZIP_ALGORITHMS - -# compatible imports -JWS = JsonWebSignature -JWE = JsonWebEncryption -JWK = JsonWebKey -JWT = JsonWebToken - -jwt = JsonWebToken() +jwt = JsonWebToken( + [ + "HS256", + "HS384", + "HS512", + "RS256", + "RS384", + "RS512", + "ES256", + "ES256K", + "ES384", + "ES512", + "PS256", + "PS384", + "PS512", + "EdDSA", + ] +) __all__ = [ - 'JoseError', - - 'JWS', 'JsonWebSignature', 'JWSAlgorithm', 'JWSHeader', 'JWSObject', - 'JWE', 'JsonWebEncryption', 'JWEAlgorithm', 'JWEEncAlgorithm', 'JWEZipAlgorithm', - - 'JWK', 'JsonWebKey', 'Key', 'KeySet', - - 'OctKey', 'RSAKey', 'ECKey', 'OKPKey', - - 'JWT', 'JsonWebToken', 'BaseClaims', 'JWTClaims', - 'jwt', + "JoseError", + "JsonWebSignature", + "JWSAlgorithm", + "JWSHeader", + "JWSObject", + "JsonWebEncryption", + "JWEAlgorithm", + "JWEEncAlgorithm", + "JWEZipAlgorithm", + "JsonWebKey", + "Key", + "KeySet", + "OctKey", + "RSAKey", + "ECKey", + "OKPKey", + "JsonWebToken", + "BaseClaims", + "JWTClaims", + "jwt", ] diff --git a/authlib/jose/drafts/__init__.py b/authlib/jose/drafts/__init__.py index b16013874..c72edb64d 100644 --- a/authlib/jose/drafts/__init__.py +++ b/authlib/jose/drafts/__init__.py @@ -1,3 +1,19 @@ -from ._jwe_enc_cryptography import register_jwe_draft +from ._jwe_algorithms import JWE_DRAFT_ALG_ALGORITHMS +from ._jwe_enc_cryptography import C20PEncAlgorithm -__all__ = ['register_jwe_draft'] +try: + from ._jwe_enc_cryptodome import XC20PEncAlgorithm +except ImportError: + XC20PEncAlgorithm = None + + +def register_jwe_draft(cls): + for alg in JWE_DRAFT_ALG_ALGORITHMS: + cls.register_algorithm(alg) + + cls.register_algorithm(C20PEncAlgorithm(256)) # C20P + if XC20PEncAlgorithm is not None: + cls.register_algorithm(XC20PEncAlgorithm(256)) # XC20P + + +__all__ = ["register_jwe_draft"] diff --git a/authlib/jose/drafts/_jwe_algorithms.py b/authlib/jose/drafts/_jwe_algorithms.py new file mode 100644 index 000000000..1b6269f5c --- /dev/null +++ b/authlib/jose/drafts/_jwe_algorithms.py @@ -0,0 +1,216 @@ +import struct + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash + +from authlib.jose.errors import InvalidEncryptionAlgorithmForECDH1PUWithKeyWrappingError +from authlib.jose.rfc7516 import JWEAlgorithmWithTagAwareKeyAgreement +from authlib.jose.rfc7518 import AESAlgorithm +from authlib.jose.rfc7518 import CBCHS2EncAlgorithm +from authlib.jose.rfc7518 import ECKey +from authlib.jose.rfc7518 import u32be_len_input +from authlib.jose.rfc8037 import OKPKey + + +class ECDH1PUAlgorithm(JWEAlgorithmWithTagAwareKeyAgreement): + EXTRA_HEADERS = ["epk", "apu", "apv", "skid"] + ALLOWED_KEY_CLS = (ECKey, OKPKey) + + # https://datatracker.ietf.org/doc/html/draft-madden-jose-ecdh-1pu-04 + def __init__(self, key_size=None): + if key_size is None: + self.name = "ECDH-1PU" + self.description = "ECDH-1PU in the Direct Key Agreement mode" + else: + self.name = f"ECDH-1PU+A{key_size}KW" + self.description = ( + f"ECDH-1PU using Concat KDF and CEK wrapped with A{key_size}KW" + ) + self.key_size = key_size + self.aeskw = AESAlgorithm(key_size) + + def prepare_key(self, raw_data): + if isinstance(raw_data, self.ALLOWED_KEY_CLS): + return raw_data + return ECKey.import_key(raw_data) + + def generate_preset(self, enc_alg, key): + epk = self._generate_ephemeral_key(key) + h = self._prepare_headers(epk) + preset = {"epk": epk, "header": h} + if self.key_size is not None: + cek = enc_alg.generate_cek() + preset["cek"] = cek + return preset + + def compute_shared_key(self, shared_key_e, shared_key_s): + return shared_key_e + shared_key_s + + def compute_fixed_info(self, headers, bit_size, tag): + if tag is None: + cctag = b"" + else: + cctag = u32be_len_input(tag) + + # AlgorithmID + if self.key_size is None: + alg_id = u32be_len_input(headers["enc"]) + else: + alg_id = u32be_len_input(headers["alg"]) + + # PartyUInfo + apu_info = u32be_len_input(headers.get("apu"), True) + + # PartyVInfo + apv_info = u32be_len_input(headers.get("apv"), True) + + # SuppPubInfo + pub_info = struct.pack(">I", bit_size) + cctag + + return alg_id + apu_info + apv_info + pub_info + + def compute_derived_key(self, shared_key, fixed_info, bit_size): + ckdf = ConcatKDFHash( + algorithm=hashes.SHA256(), + length=bit_size // 8, + otherinfo=fixed_info, + backend=default_backend(), + ) + return ckdf.derive(shared_key) + + def deliver_at_sender( + self, + sender_static_key, + sender_ephemeral_key, + recipient_pubkey, + headers, + bit_size, + tag, + ): + shared_key_s = sender_static_key.exchange_shared_key(recipient_pubkey) + shared_key_e = sender_ephemeral_key.exchange_shared_key(recipient_pubkey) + shared_key = self.compute_shared_key(shared_key_e, shared_key_s) + + fixed_info = self.compute_fixed_info(headers, bit_size, tag) + + return self.compute_derived_key(shared_key, fixed_info, bit_size) + + def deliver_at_recipient( + self, + recipient_key, + sender_static_pubkey, + sender_ephemeral_pubkey, + headers, + bit_size, + tag, + ): + shared_key_s = recipient_key.exchange_shared_key(sender_static_pubkey) + shared_key_e = recipient_key.exchange_shared_key(sender_ephemeral_pubkey) + shared_key = self.compute_shared_key(shared_key_e, shared_key_s) + + fixed_info = self.compute_fixed_info(headers, bit_size, tag) + + return self.compute_derived_key(shared_key, fixed_info, bit_size) + + def _generate_ephemeral_key(self, key): + return key.generate_key(key["crv"], is_private=True) + + def _prepare_headers(self, epk): + # REQUIRED_JSON_FIELDS contains only public fields + pub_epk = {k: epk[k] for k in epk.REQUIRED_JSON_FIELDS} + pub_epk["kty"] = epk.kty + return {"epk": pub_epk} + + def generate_keys_and_prepare_headers(self, enc_alg, key, sender_key, preset=None): + if not isinstance(enc_alg, CBCHS2EncAlgorithm): + raise InvalidEncryptionAlgorithmForECDH1PUWithKeyWrappingError() + + if preset and "epk" in preset: + epk = preset["epk"] + h = {} + else: + epk = self._generate_ephemeral_key(key) + h = self._prepare_headers(epk) + + if preset and "cek" in preset: + cek = preset["cek"] + else: + cek = enc_alg.generate_cek() + + return {"epk": epk, "cek": cek, "header": h} + + def _agree_upon_key_at_sender( + self, enc_alg, headers, key, sender_key, epk, tag=None + ): + if self.key_size is None: + bit_size = enc_alg.CEK_SIZE + else: + bit_size = self.key_size + + public_key = key.get_op_key("wrapKey") + + return self.deliver_at_sender( + sender_key, epk, public_key, headers, bit_size, tag + ) + + def _wrap_cek(self, cek, dk): + kek = self.aeskw.prepare_key(dk) + return self.aeskw.wrap_cek(cek, kek) + + def agree_upon_key_and_wrap_cek( + self, enc_alg, headers, key, sender_key, epk, cek, tag + ): + dk = self._agree_upon_key_at_sender(enc_alg, headers, key, sender_key, epk, tag) + return self._wrap_cek(cek, dk) + + def wrap(self, enc_alg, headers, key, sender_key, preset=None): + # In this class this method is used in direct key agreement mode only + if self.key_size is not None: + raise RuntimeError("Invalid algorithm state detected") + + if preset and "epk" in preset: + epk = preset["epk"] + h = {} + else: + epk = self._generate_ephemeral_key(key) + h = self._prepare_headers(epk) + + dk = self._agree_upon_key_at_sender(enc_alg, headers, key, sender_key, epk) + + return {"ek": b"", "cek": dk, "header": h} + + def unwrap(self, enc_alg, ek, headers, key, sender_key, tag=None): + if "epk" not in headers: + raise ValueError('Missing "epk" in headers') + + if self.key_size is None: + bit_size = enc_alg.CEK_SIZE + else: + bit_size = self.key_size + + sender_pubkey = sender_key.get_op_key("wrapKey") + epk = key.import_key(headers["epk"]) + epk_pubkey = epk.get_op_key("wrapKey") + dk = self.deliver_at_recipient( + key, sender_pubkey, epk_pubkey, headers, bit_size, tag + ) + + if self.key_size is None: + return dk + + kek = self.aeskw.prepare_key(dk) + return self.aeskw.unwrap(enc_alg, ek, headers, kek) + + +JWE_DRAFT_ALG_ALGORITHMS = [ + ECDH1PUAlgorithm(None), # ECDH-1PU + ECDH1PUAlgorithm(128), # ECDH-1PU+A128KW + ECDH1PUAlgorithm(192), # ECDH-1PU+A192KW + ECDH1PUAlgorithm(256), # ECDH-1PU+A256KW +] + + +def register_jwe_alg_draft(cls): + for alg in JWE_DRAFT_ALG_ALGORITHMS: + cls.register_algorithm(alg) diff --git a/authlib/jose/drafts/_jwe_enc_cryptodome.py b/authlib/jose/drafts/_jwe_enc_cryptodome.py new file mode 100644 index 000000000..e53e35318 --- /dev/null +++ b/authlib/jose/drafts/_jwe_enc_cryptodome.py @@ -0,0 +1,53 @@ +"""authlib.jose.draft. +~~~~~~~~~~~~~~~~~~~~ + +Content Encryption per `Section 4`_. + +.. _`Section 4`: https://datatracker.ietf.org/doc/html/draft-amringer-jose-chacha-02#section-4 +""" + +from Cryptodome.Cipher import ChaCha20_Poly1305 as Cryptodome_ChaCha20_Poly1305 + +from authlib.jose.rfc7516 import JWEEncAlgorithm + + +class XC20PEncAlgorithm(JWEEncAlgorithm): + # Use of an IV of size 192 bits is REQUIRED with this algorithm. + # https://datatracker.ietf.org/doc/html/draft-amringer-jose-chacha-02#section-4.1 + IV_SIZE = 192 + + def __init__(self, key_size): + self.name = "XC20P" + self.description = "XChaCha20-Poly1305" + self.key_size = key_size + self.CEK_SIZE = key_size + + def encrypt(self, msg, aad, iv, key): + """Content Encryption with AEAD_XCHACHA20_POLY1305. + + :param msg: text to be encrypt in bytes + :param aad: additional authenticated data in bytes + :param iv: initialization vector in bytes + :param key: encrypted key in bytes + :return: (ciphertext, tag) + """ + self.check_iv(iv) + chacha = Cryptodome_ChaCha20_Poly1305.new(key=key, nonce=iv) + chacha.update(aad) + ciphertext, tag = chacha.encrypt_and_digest(msg) + return ciphertext, tag + + def decrypt(self, ciphertext, aad, iv, tag, key): + """Content Decryption with AEAD_XCHACHA20_POLY1305. + + :param ciphertext: ciphertext in bytes + :param aad: additional authenticated data in bytes + :param iv: initialization vector in bytes + :param tag: authentication tag in bytes + :param key: encrypted key in bytes + :return: message + """ + self.check_iv(iv) + chacha = Cryptodome_ChaCha20_Poly1305.new(key=key, nonce=iv) + chacha.update(aad) + return chacha.decrypt_and_verify(ciphertext, tag) diff --git a/authlib/jose/drafts/_jwe_enc_cryptography.py b/authlib/jose/drafts/_jwe_enc_cryptography.py index 806eab93e..f689c30dc 100644 --- a/authlib/jose/drafts/_jwe_enc_cryptography.py +++ b/authlib/jose/drafts/_jwe_enc_cryptography.py @@ -1,28 +1,29 @@ -""" - authlib.jose.draft - ~~~~~~~~~~~~~~~~~~~~ +"""authlib.jose.draft. +~~~~~~~~~~~~~~~~~~~~ - Content Encryption per `Section 4`_. +Content Encryption per `Section 4`_. - .. _`Section 4`: https://tools.ietf.org/html/draft-amringer-jose-chacha-02#section-4 +.. _`Section 4`: https://datatracker.ietf.org/doc/html/draft-amringer-jose-chacha-02#section-4 """ + from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 -from authlib.jose.rfc7516 import JWEEncAlgorithm, JsonWebEncryption + +from authlib.jose.rfc7516 import JWEEncAlgorithm class C20PEncAlgorithm(JWEEncAlgorithm): # Use of an IV of size 96 bits is REQUIRED with this algorithm. - # https://tools.ietf.org/html/draft-amringer-jose-chacha-02#section-2.2.1 + # https://datatracker.ietf.org/doc/html/draft-amringer-jose-chacha-02#section-4.1 IV_SIZE = 96 def __init__(self, key_size): - self.name = 'C20P' - self.description = 'ChaCha20-Poly1305' + self.name = "C20P" + self.description = "ChaCha20-Poly1305" self.key_size = key_size self.CEK_SIZE = key_size def encrypt(self, msg, aad, iv, key): - """Key Encryption with AES GCM + """Content Encryption with AEAD_CHACHA20_POLY1305. :param msg: text to be encrypt in bytes :param aad: additional authenticated data in bytes @@ -36,7 +37,7 @@ def encrypt(self, msg, aad, iv, key): return ciphertext[:-16], ciphertext[-16:] def decrypt(self, ciphertext, aad, iv, tag, key): - """Key Decryption with AES GCM + """Content Decryption with AEAD_CHACHA20_POLY1305. :param ciphertext: ciphertext in bytes :param aad: additional authenticated data in bytes @@ -48,7 +49,3 @@ def decrypt(self, ciphertext, aad, iv, tag, key): self.check_iv(iv) chacha = ChaCha20Poly1305(key) return chacha.decrypt(iv, ciphertext + tag, aad) - - -def register_jwe_draft(): - JsonWebEncryption.register_algorithm(C20PEncAlgorithm(256)) # C20P diff --git a/authlib/jose/errors.py b/authlib/jose/errors.py index 2174b42e9..385a866ef 100644 --- a/authlib/jose/errors.py +++ b/authlib/jose/errors.py @@ -6,83 +6,115 @@ class JoseError(AuthlibBaseError): class DecodeError(JoseError): - error = 'decode_error' + error = "decode_error" class MissingAlgorithmError(JoseError): - error = 'missing_algorithm' + error = "missing_algorithm" class UnsupportedAlgorithmError(JoseError): - error = 'unsupported_algorithm' + error = "unsupported_algorithm" class BadSignatureError(JoseError): - error = 'bad_signature' + error = "bad_signature" def __init__(self, result): - super(BadSignatureError, self).__init__() + super().__init__() self.result = result -class InvalidHeaderParameterName(JoseError): - error = 'invalid_header_parameter_name' +class InvalidHeaderParameterNameError(JoseError): + error = "invalid_header_parameter_name" def __init__(self, name): - description = 'Invalid Header Parameter Names: {}'.format(name) - super(InvalidHeaderParameterName, self).__init__( - description=description) + description = f"Invalid Header Parameter Name: {name}" + super().__init__(description=description) + + +class InvalidCritHeaderParameterNameError(JoseError): + error = "invalid_crit_header_parameter_name" + + def __init__(self, name): + description = f"Invalid Header Parameter Name: {name}" + super().__init__(description=description) + + +class InvalidEncryptionAlgorithmForECDH1PUWithKeyWrappingError(JoseError): + error = "invalid_encryption_algorithm_for_ECDH_1PU_with_key_wrapping" + + def __init__(self): + description = ( + "In key agreement with key wrapping mode ECDH-1PU algorithm " + "only supports AES_CBC_HMAC_SHA2 family encryption algorithms" + ) + super().__init__(description=description) + + +class InvalidAlgorithmForMultipleRecipientsMode(JoseError): + error = "invalid_algorithm_for_multiple_recipients_mode" + + def __init__(self, alg): + description = f"{alg} algorithm cannot be used in multiple recipients mode" + super().__init__(description=description) + + +class KeyMismatchError(JoseError): + error = "key_mismatch_error" + description = "Key does not match to any recipient" class MissingEncryptionAlgorithmError(JoseError): - error = 'missing_encryption_algorithm' - description = 'Missing "enc" in header' + error = "missing_encryption_algorithm" + description = "Missing 'enc' in header" class UnsupportedEncryptionAlgorithmError(JoseError): - error = 'unsupported_encryption_algorithm' - description = 'Unsupported "enc" value in header' + error = "unsupported_encryption_algorithm" + description = "Unsupported 'enc' value in header" class UnsupportedCompressionAlgorithmError(JoseError): - error = 'unsupported_compression_algorithm' - description = 'Unsupported "zip" value in header' + error = "unsupported_compression_algorithm" + description = "Unsupported 'zip' value in header" class InvalidUseError(JoseError): - error = 'invalid_use' - description = 'Key "use" is not valid for your usage' + error = "invalid_use" + description = "Key 'use' is not valid for your usage" class InvalidClaimError(JoseError): - error = 'invalid_claim' + error = "invalid_claim" def __init__(self, claim): - description = 'Invalid claim "{}"'.format(claim) - super(InvalidClaimError, self).__init__(description=description) + self.claim_name = claim + description = f"Invalid claim '{claim}'" + super().__init__(description=description) class MissingClaimError(JoseError): - error = 'missing_claim' + error = "missing_claim" def __init__(self, claim): - description = 'Missing "{}" claim'.format(claim) - super(MissingClaimError, self).__init__(description=description) + description = f"Missing '{claim}' claim" + super().__init__(description=description) class InsecureClaimError(JoseError): - error = 'insecure_claim' + error = "insecure_claim" def __init__(self, claim): - description = 'Insecure claim "{}"'.format(claim) - super(InsecureClaimError, self).__init__(description=description) + description = f"Insecure claim '{claim}'" + super().__init__(description=description) class ExpiredTokenError(JoseError): - error = 'expired_token' - description = 'The token is expired' + error = "expired_token" + description = "The token is expired" class InvalidTokenError(JoseError): - error = 'invalid_token' - description = 'The token is not valid yet' + error = "invalid_token" + description = "The token is not valid yet" diff --git a/authlib/jose/jwk.py b/authlib/jose/jwk.py index c78ef70cf..e1debb57c 100644 --- a/authlib/jose/jwk.py +++ b/authlib/jose/jwk.py @@ -1,80 +1,10 @@ -from authlib.common.encoding import text_types, json_loads -from .rfc7517 import KeySet -from .rfc7518 import ( - OctKey, - RSAKey, - ECKey, - load_pem_key, -) -from .rfc8037 import OKPKey +from authlib.deprecate import deprecate - -class JsonWebKey(object): - JWK_KEY_CLS = { - OctKey.kty: OctKey, - RSAKey.kty: RSAKey, - ECKey.kty: ECKey, - OKPKey.kty: OKPKey, - } - - @classmethod - def generate_key(cls, kty, crv_or_size, options=None, is_private=False): - """Generate a Key with the given key type, curve name or bit size. - - :param kty: string of ``oct``, ``RSA``, ``EC``, ``OKP`` - :param crv_or_size: curve name or bit size - :param options: a dict of other options for Key - :param is_private: create a private key or public key - :return: Key instance - """ - key_cls = cls.JWK_KEY_CLS[kty] - return key_cls.generate_key(crv_or_size, options, is_private) - - @classmethod - def import_key(cls, raw, options=None): - """Import a Key from bytes, string, PEM or dict. - - :return: Key instance - """ - kty = None - if options is not None: - kty = options.get('kty') - - if kty is None and isinstance(raw, dict): - kty = raw.get('kty') - - if kty is None: - raw_key = load_pem_key(raw) - for _kty in cls.JWK_KEY_CLS: - key_cls = cls.JWK_KEY_CLS[_kty] - if isinstance(raw_key, key_cls.RAW_KEY_CLS): - return key_cls.import_key(raw_key, options) - - key_cls = cls.JWK_KEY_CLS[kty] - return key_cls.import_key(raw, options) - - @classmethod - def import_key_set(cls, raw): - """Import KeySet from string, dict or a list of keys. - - :return: KeySet instance - """ - if isinstance(raw, text_types) and \ - raw.startswith('{') and raw.endswith('}'): - raw = json_loads(raw) - keys = raw.get('keys') - elif isinstance(raw, dict) and 'keys' in raw: - keys = raw.get('keys') - elif isinstance(raw, (tuple, list)): - keys = raw - else: - return None - - return KeySet([cls.import_key(k) for k in keys]) +from .rfc7517 import JsonWebKey def loads(obj, kid=None): - # TODO: deprecate + deprecate("Please use ``JsonWebKey`` directly.") key_set = JsonWebKey.import_key_set(obj) if key_set: return key_set.find_by_kid(kid) @@ -82,10 +12,9 @@ def loads(obj, kid=None): def dumps(key, kty=None, **params): - # TODO: deprecate + deprecate("Please use ``JsonWebKey`` directly.") if kty: - params['kty'] = kty + params["kty"] = kty key = JsonWebKey.import_key(key, params) - data = key.as_dict() - return data + return dict(key) diff --git a/authlib/jose/rfc7515/__init__.py b/authlib/jose/rfc7515/__init__.py index 5f8e0f5f5..7c657515d 100644 --- a/authlib/jose/rfc7515/__init__.py +++ b/authlib/jose/rfc7515/__init__.py @@ -1,18 +1,15 @@ -""" - authlib.jose.rfc7515 - ~~~~~~~~~~~~~~~~~~~~~ +"""authlib.jose.rfc7515. +~~~~~~~~~~~~~~~~~~~~~ - This module represents a direct implementation of - JSON Web Signature (JWS). +This module represents a direct implementation of +JSON Web Signature (JWS). - https://tools.ietf.org/html/rfc7515 +https://tools.ietf.org/html/rfc7515 """ from .jws import JsonWebSignature -from .models import JWSAlgorithm, JWSHeader, JWSObject - +from .models import JWSAlgorithm +from .models import JWSHeader +from .models import JWSObject -__all__ = [ - 'JsonWebSignature', - 'JWSAlgorithm', 'JWSHeader', 'JWSObject' -] +__all__ = ["JsonWebSignature", "JWSAlgorithm", "JWSHeader", "JWSObject"] diff --git a/authlib/jose/rfc7515/jws.py b/authlib/jose/rfc7515/jws.py index 20920559b..92e24ce52 100644 --- a/authlib/jose/rfc7515/jws.py +++ b/authlib/jose/rfc7515/jws.py @@ -1,32 +1,40 @@ -from authlib.common.encoding import ( - to_bytes, - to_unicode, - urlsafe_b64encode, - json_b64encode, - json_loads, -) -from authlib.jose.util import ( - extract_header, - extract_segment, -) -from authlib.jose.errors import ( - DecodeError, - MissingAlgorithmError, - UnsupportedAlgorithmError, - BadSignatureError, - InvalidHeaderParameterName, -) -from .models import JWSHeader, JWSObject - - -class JsonWebSignature(object): - +from authlib.common.encoding import json_b64encode +from authlib.common.encoding import to_bytes +from authlib.common.encoding import to_unicode +from authlib.common.encoding import urlsafe_b64encode +from authlib.jose.errors import BadSignatureError +from authlib.jose.errors import DecodeError +from authlib.jose.errors import InvalidCritHeaderParameterNameError +from authlib.jose.errors import InvalidHeaderParameterNameError +from authlib.jose.errors import MissingAlgorithmError +from authlib.jose.errors import UnsupportedAlgorithmError +from authlib.jose.util import ensure_dict +from authlib.jose.util import extract_header +from authlib.jose.util import extract_segment + +from .models import JWSHeader +from .models import JWSObject + + +class JsonWebSignature: #: Registered Header Parameter Names defined by Section 4.1 - REGISTERED_HEADER_PARAMETER_NAMES = frozenset([ - 'alg', 'jku', 'jwk', 'kid', - 'x5u', 'x5c', 'x5t', 'x5t#S256', - 'typ', 'cty', 'crit' - ]) + REGISTERED_HEADER_PARAMETER_NAMES = frozenset( + [ + "alg", + "jku", + "jwk", + "kid", + "x5u", + "x5c", + "x5t", + "x5t#S256", + "typ", + "cty", + "crit", + ] + ) + + MAX_CONTENT_LENGTH: int = 256000 #: Defined available JWS algorithms in the registry ALGORITHMS_REGISTRY = {} @@ -37,9 +45,8 @@ def __init__(self, algorithms=None, private_headers=None): @classmethod def register_algorithm(cls, algorithm): - if not algorithm or algorithm.algorithm_type != 'JWS': - raise ValueError( - 'Invalid algorithm for JWS, {!r}'.format(algorithm)) + if not algorithm or algorithm.algorithm_type != "JWS": + raise ValueError(f"Invalid algorithm for JWS, {algorithm!r}") cls.ALGORITHMS_REGISTRY[algorithm.name] = algorithm def serialize_compact(self, protected, payload, key): @@ -60,15 +67,16 @@ def serialize_compact(self, protected, payload, key): """ jws_header = JWSHeader(protected, None) self._validate_private_headers(protected) + self._validate_crit_headers(protected) algorithm, key = self._prepare_algorithm_key(protected, payload, key) protected_segment = json_b64encode(jws_header.protected) payload_segment = urlsafe_b64encode(to_bytes(payload)) # calculate signature - signing_input = b'.'.join([protected_segment, payload_segment]) + signing_input = b".".join([protected_segment, payload_segment]) signature = urlsafe_b64encode(algorithm.sign(signing_input, key)) - return b'.'.join([protected_segment, payload_segment, signature]) + return b".".join([protected_segment, payload_segment, signature]) def deserialize_compact(self, s, key, decode=None): """Exact JWS Compact Serialization, and validate with the given key. @@ -83,14 +91,18 @@ def deserialize_compact(self, s, key, decode=None): .. _`Section 7.1`: https://tools.ietf.org/html/rfc7515#section-7.1 """ + if len(s) > self.MAX_CONTENT_LENGTH: + raise ValueError("Serialization is too long.") + try: s = to_bytes(s) - signing_input, signature_segment = s.rsplit(b'.', 1) - protected_segment, payload_segment = signing_input.split(b'.', 1) - except ValueError: - raise DecodeError('Not enough segments') + signing_input, signature_segment = s.rsplit(b".", 1) + protected_segment, payload_segment = signing_input.split(b".", 1) + except ValueError as exc: + raise DecodeError("Not enough segments") from exc protected = _extract_header(protected_segment) + self._validate_crit_headers(protected) jws_header = JWSHeader(protected, None) payload = _extract_payload(payload_segment) @@ -98,7 +110,7 @@ def deserialize_compact(self, s, key, decode=None): payload = decode(payload) signature = _extract_signature(signature_segment) - rv = JWSObject(jws_header, payload, 'compact') + rv = JWSObject(jws_header, payload, "compact") algorithm, key = self._prepare_algorithm_key(jws_header, payload, key) if algorithm.verify(signing_input, signature, key): return rv @@ -128,30 +140,32 @@ def serialize_json(self, header_obj, payload, key): def _sign(jws_header): self._validate_private_headers(jws_header) + # RFC 7515 §4.1.11: 'crit' MUST be integrity-protected. + # Reject if present in unprotected header, and validate only + # against the protected header parameters. + self._reject_unprotected_crit(jws_header.header) + self._validate_crit_headers(jws_header.protected) _alg, _key = self._prepare_algorithm_key(jws_header, payload, key) protected_segment = json_b64encode(jws_header.protected) - signing_input = b'.'.join([protected_segment, payload_segment]) + signing_input = b".".join([protected_segment, payload_segment]) signature = urlsafe_b64encode(_alg.sign(signing_input, _key)) rv = { - 'protected': to_unicode(protected_segment), - 'signature': to_unicode(signature) + "protected": to_unicode(protected_segment), + "signature": to_unicode(signature), } if jws_header.header is not None: - rv['header'] = jws_header.header + rv["header"] = jws_header.header return rv if isinstance(header_obj, dict): data = _sign(JWSHeader.from_dict(header_obj)) - data['payload'] = to_unicode(payload_segment) + data["payload"] = to_unicode(payload_segment) return data signatures = [_sign(JWSHeader.from_dict(h)) for h in header_obj] - return { - 'payload': to_unicode(payload_segment), - 'signatures': signatures - } + return {"payload": to_unicode(payload_segment), "signatures": signatures} def deserialize_json(self, obj, key, decode=None): """Exact JWS JSON Serialization, and validate with the given key. @@ -166,10 +180,10 @@ def deserialize_json(self, obj, key, decode=None): .. _`Section 7.2`: https://tools.ietf.org/html/rfc7515#section-7.2 """ - obj = _ensure_dict(obj) + obj = ensure_dict(obj, "JWS") - payload_segment = obj.get('payload') - if not payload_segment: + payload_segment = obj.get("payload") + if payload_segment is None: raise DecodeError('Missing "payload" value') payload_segment = to_bytes(payload_segment) @@ -177,26 +191,28 @@ def deserialize_json(self, obj, key, decode=None): if decode: payload = decode(payload) - if 'signatures' not in obj: + if "signatures" not in obj: # flattened JSON JWS jws_header, valid = self._validate_json_jws( - payload_segment, payload, obj, key) + payload_segment, payload, obj, key + ) - rv = JWSObject(jws_header, payload, 'flat') + rv = JWSObject(jws_header, payload, "flat") if valid: return rv raise BadSignatureError(rv) headers = [] is_valid = True - for header_obj in obj['signatures']: + for header_obj in obj["signatures"]: jws_header, valid = self._validate_json_jws( - payload_segment, payload, header_obj, key) + payload_segment, payload, header_obj, key + ) headers.append(jws_header) if not valid: is_valid = False - rv = JWSObject(headers, payload, 'json') + rv = JWSObject(headers, payload, "json") if is_valid: return rv raise BadSignatureError(rv) @@ -215,7 +231,7 @@ def serialize(self, header, payload, key): """ if isinstance(header, (list, tuple)): return self.serialize_json(header, payload, key) - if 'protected' in header: + if "protected" in header: return self.serialize_json(header, payload, key) return self.serialize_compact(header, payload, key) @@ -236,25 +252,27 @@ def deserialize(self, s, key, decode=None): return self.deserialize_json(s, key, decode) s = to_bytes(s) - if s.startswith(b'{') and s.endswith(b'}'): + if s.startswith(b"{") and s.endswith(b"}"): return self.deserialize_json(s, key, decode) return self.deserialize_compact(s, key, decode) def _prepare_algorithm_key(self, header, payload, key): - if 'alg' not in header: + if "alg" not in header: raise MissingAlgorithmError() - alg = header['alg'] - if self._algorithms and alg not in self._algorithms: - raise UnsupportedAlgorithmError() + alg = header["alg"] if alg not in self.ALGORITHMS_REGISTRY: raise UnsupportedAlgorithmError() algorithm = self.ALGORITHMS_REGISTRY[alg] + if self._algorithms is None: + if algorithm.deprecated: + raise UnsupportedAlgorithmError() + elif alg not in self._algorithms: + raise UnsupportedAlgorithmError() + if callable(key): key = key(header, payload) - elif 'jwk' in header: - key = header['jwk'] key = algorithm.prepare_key(key) return algorithm, key @@ -267,26 +285,55 @@ def _validate_private_headers(self, header): for k in header: if k not in names: - raise InvalidHeaderParameterName(k) + raise InvalidHeaderParameterNameError(k) + + def _reject_unprotected_crit(self, unprotected_header): + """Reject 'crit' when found in the unprotected header (RFC 7515 §4.1.11).""" + if unprotected_header and "crit" in unprotected_header: + raise InvalidHeaderParameterNameError("crit") + + def _validate_crit_headers(self, header): + if "crit" in header: + crit_headers = header["crit"] + # Type enforcement for robustness and predictable errors + if not isinstance(crit_headers, list) or not all( + isinstance(x, str) for x in crit_headers + ): + raise InvalidHeaderParameterNameError("crit") + names = self.REGISTERED_HEADER_PARAMETER_NAMES.copy() + if self._private_headers: + names = names.union(self._private_headers) + for k in crit_headers: + if k not in names: + raise InvalidCritHeaderParameterNameError(k) + elif k not in header: + raise InvalidCritHeaderParameterNameError(k) def _validate_json_jws(self, payload_segment, payload, header_obj, key): - protected_segment = header_obj.get('protected') + protected_segment = header_obj.get("protected") if not protected_segment: raise DecodeError('Missing "protected" value') - signature_segment = header_obj.get('signature') + signature_segment = header_obj.get("signature") if not signature_segment: raise DecodeError('Missing "signature" value') protected_segment = to_bytes(protected_segment) protected = _extract_header(protected_segment) - header = header_obj.get('header') + header = header_obj.get("header") if header and not isinstance(header, dict): raise DecodeError('Invalid "header" value') - + # RFC 7515 §4.1.11: 'crit' MUST be integrity-protected. If present in + # the unprotected header object, reject the JWS. + self._reject_unprotected_crit(header) + + # Enforce must-understand semantics for names listed in protected + # 'crit'. This will also ensure each listed name is present in the + # protected header. + self._validate_crit_headers(protected) jws_header = JWSHeader(protected, header) algorithm, key = self._prepare_algorithm_key(jws_header, payload, key) - signing_input = b'.'.join([protected_segment, payload_segment]) + signing_input = b".".join([protected_segment, payload_segment]) signature = _extract_signature(to_bytes(signature_segment)) if algorithm.verify(signing_input, signature, key): return jws_header, True @@ -298,21 +345,8 @@ def _extract_header(header_segment): def _extract_signature(signature_segment): - return extract_segment(signature_segment, DecodeError, 'signature') + return extract_segment(signature_segment, DecodeError, "signature") def _extract_payload(payload_segment): - return extract_segment(payload_segment, DecodeError, 'payload') - - -def _ensure_dict(s): - if not isinstance(s, dict): - try: - s = json_loads(to_unicode(s)) - except (ValueError, TypeError): - raise DecodeError('Invalid JWS') - - if not isinstance(s, dict): - raise DecodeError('Invalid JWS') - - return s + return extract_segment(payload_segment, DecodeError, "payload") diff --git a/authlib/jose/rfc7515/models.py b/authlib/jose/rfc7515/models.py index caccfb4e7..b1261b421 100644 --- a/authlib/jose/rfc7515/models.py +++ b/authlib/jose/rfc7515/models.py @@ -1,11 +1,13 @@ -class JWSAlgorithm(object): +class JWSAlgorithm: """Interface for JWS algorithm. JWA specification (RFC7518) SHOULD implement the algorithms for JWS with this base implementation. """ + name = None description = None - algorithm_type = 'JWS' - algorithm_location = 'alg' + deprecated = False + algorithm_type = "JWS" + algorithm_location = "alg" def prepare_key(self, raw_data): """Prepare key for signing and verifying signature.""" @@ -35,8 +37,8 @@ class JWSHeader(dict): """Header object for JWS. It combine the protected header and unprotected header together. JWSHeader itself is a dict of the combined dict. e.g. - >>> protected = {'alg': 'HS256'} - >>> header = {'kid': 'a'} + >>> protected = {"alg": "HS256"} + >>> header = {"kid": "a"} >>> jws_header = JWSHeader(protected, header) >>> print(jws_header) {'alg': 'HS256', 'kid': 'a'} @@ -46,13 +48,14 @@ class JWSHeader(dict): :param protected: dict of protected header :param header: dict of unprotected header """ + def __init__(self, protected, header): obj = {} - if protected: - obj.update(protected) if header: obj.update(header) - super(JWSHeader, self).__init__(obj) + if protected: + obj.update(protected) + super().__init__(obj) self.protected = protected self.header = header @@ -60,13 +63,14 @@ def __init__(self, protected, header): def from_dict(cls, obj): if isinstance(obj, cls): return obj - return cls(obj.get('protected'), obj.get('header')) + return cls(obj.get("protected"), obj.get("header")) class JWSObject(dict): """A dict instance to represent a JWS object.""" - def __init__(self, header, payload, type='compact'): - super(JWSObject, self).__init__( + + def __init__(self, header, payload, type="compact"): + super().__init__( header=header, payload=payload, ) @@ -77,5 +81,5 @@ def __init__(self, header, payload, type='compact'): @property def headers(self): """Alias of ``header`` for JSON typed JWS.""" - if self.type == 'json': - return self['header'] + if self.type == "json": + return self["header"] diff --git a/authlib/jose/rfc7516/__init__.py b/authlib/jose/rfc7516/__init__.py index f7f3c3157..e38e17841 100644 --- a/authlib/jose/rfc7516/__init__.py +++ b/authlib/jose/rfc7516/__init__.py @@ -1,18 +1,22 @@ -""" - authlib.jose.rfc7516 - ~~~~~~~~~~~~~~~~~~~~~ +"""authlib.jose.rfc7516. +~~~~~~~~~~~~~~~~~~~~~ - This module represents a direct implementation of - JSON Web Encryption (JWE). +This module represents a direct implementation of +JSON Web Encryption (JWE). - https://tools.ietf.org/html/rfc7516 +https://tools.ietf.org/html/rfc7516 """ from .jwe import JsonWebEncryption -from .models import JWEAlgorithm, JWEEncAlgorithm, JWEZipAlgorithm - +from .models import JWEAlgorithm +from .models import JWEAlgorithmWithTagAwareKeyAgreement +from .models import JWEEncAlgorithm +from .models import JWEZipAlgorithm __all__ = [ - 'JsonWebEncryption', - 'JWEAlgorithm', 'JWEEncAlgorithm', 'JWEZipAlgorithm' + "JsonWebEncryption", + "JWEAlgorithm", + "JWEAlgorithmWithTagAwareKeyAgreement", + "JWEEncAlgorithm", + "JWEZipAlgorithm", ] diff --git a/authlib/jose/rfc7516/jwe.py b/authlib/jose/rfc7516/jwe.py index 0e5d84de3..3cfc93729 100644 --- a/authlib/jose/rfc7516/jwe.py +++ b/authlib/jose/rfc7516/jwe.py @@ -1,29 +1,46 @@ -from authlib.common.encoding import ( - to_bytes, urlsafe_b64encode, json_b64encode -) -from authlib.jose.util import ( - extract_header, - extract_segment, -) -from authlib.jose.errors import ( - DecodeError, - MissingAlgorithmError, - UnsupportedAlgorithmError, - MissingEncryptionAlgorithmError, - UnsupportedEncryptionAlgorithmError, - UnsupportedCompressionAlgorithmError, - InvalidHeaderParameterName, -) - - -class JsonWebEncryption(object): +from collections import OrderedDict +from copy import deepcopy + +from authlib.common.encoding import json_b64encode +from authlib.common.encoding import to_bytes +from authlib.common.encoding import to_unicode +from authlib.common.encoding import urlsafe_b64encode +from authlib.jose.errors import DecodeError +from authlib.jose.errors import InvalidAlgorithmForMultipleRecipientsMode +from authlib.jose.errors import InvalidHeaderParameterNameError +from authlib.jose.errors import KeyMismatchError +from authlib.jose.errors import MissingAlgorithmError +from authlib.jose.errors import MissingEncryptionAlgorithmError +from authlib.jose.errors import UnsupportedAlgorithmError +from authlib.jose.errors import UnsupportedCompressionAlgorithmError +from authlib.jose.errors import UnsupportedEncryptionAlgorithmError +from authlib.jose.rfc7516.models import JWEAlgorithmWithTagAwareKeyAgreement +from authlib.jose.rfc7516.models import JWEHeader +from authlib.jose.rfc7516.models import JWESharedHeader +from authlib.jose.util import ensure_dict +from authlib.jose.util import extract_header +from authlib.jose.util import extract_segment + + +class JsonWebEncryption: #: Registered Header Parameter Names defined by Section 4.1 - REGISTERED_HEADER_PARAMETER_NAMES = frozenset([ - 'alg', 'enc', 'zip', - 'jku', 'jwk', 'kid', - 'x5u', 'x5c', 'x5t', 'x5t#S256', - 'typ', 'cty', 'crit' - ]) + REGISTERED_HEADER_PARAMETER_NAMES = frozenset( + [ + "alg", + "enc", + "zip", + "jku", + "jwk", + "kid", + "x5u", + "x5c", + "x5t", + "x5t#S256", + "typ", + "cty", + "crit", + ] + ) ALG_REGISTRY = {} ENC_REGISTRY = {} @@ -36,21 +53,21 @@ def __init__(self, algorithms=None, private_headers=None): @classmethod def register_algorithm(cls, algorithm): """Register an algorithm for ``alg`` or ``enc`` or ``zip`` of JWE.""" - if not algorithm or algorithm.algorithm_type != 'JWE': - raise ValueError( - 'Invalid algorithm for JWE, {!r}'.format(algorithm)) + if not algorithm or algorithm.algorithm_type != "JWE": + raise ValueError(f"Invalid algorithm for JWE, {algorithm!r}") - if algorithm.algorithm_location == 'alg': + if algorithm.algorithm_location == "alg": cls.ALG_REGISTRY[algorithm.name] = algorithm - elif algorithm.algorithm_location == 'enc': + elif algorithm.algorithm_location == "enc": cls.ENC_REGISTRY[algorithm.name] = algorithm - elif algorithm.algorithm_location == 'zip': + elif algorithm.algorithm_location == "zip": cls.ZIP_REGISTRY[algorithm.name] = algorithm - def serialize_compact(self, protected, payload, key): - """Generate a JWE Compact Serialization. The JWE Compact Serialization - represents encrypted content as a compact, URL-safe string. This - string is: + def serialize_compact(self, protected, payload, key, sender_key=None): + """Generate a JWE Compact Serialization. + + The JWE Compact Serialization represents encrypted content as a compact, + URL-safe string. This string is:: BASE64URL(UTF8(JWE Protected Header)) || '.' || BASE64URL(JWE Encrypted Key) || '.' || @@ -63,30 +80,54 @@ def serialize_compact(self, protected, payload, key): Per-Recipient Unprotected Header, or JWE AAD values. :param protected: A dict of protected header - :param payload: A string/dict of payload - :param key: Private key used to generate signature - :return: byte + :param payload: Payload (bytes or a value convertible to bytes) + :param key: Public key used to encrypt payload + :param sender_key: Sender's private key in case + JWEAlgorithmWithTagAwareKeyAgreement is used + :return: JWE compact serialization as bytes """ - # step 1: Prepare algorithms & key alg = self.get_header_alg(protected) enc = self.get_header_enc(protected) zip_alg = self.get_header_zip(protected) + + self._validate_sender_key(sender_key, alg) self._validate_private_headers(protected, alg) key = prepare_key(alg, protected, key) + if sender_key is not None: + sender_key = alg.prepare_key(sender_key) # self._post_validate_header(protected, algorithm) # step 2: Generate a random Content Encryption Key (CEK) - # use enc_alg.generate_cek() in .wrap method + # use enc_alg.generate_cek() in scope of upcoming .wrap + # or .generate_keys_and_prepare_headers call # step 3: Encrypt the CEK with the recipient's public key - wrapped = alg.wrap(enc, protected, key) - cek = wrapped['cek'] - ek = wrapped['ek'] - if 'header' in wrapped: - protected.update(wrapped['header']) + if ( + isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement) + and alg.key_size is not None + ): + # For a JWE algorithm with tag-aware key agreement in case key agreement + # with key wrapping mode is used: + # Defer key agreement with key wrapping until + # authentication tag is computed + prep = alg.generate_keys_and_prepare_headers(enc, key, sender_key) + epk = prep["epk"] + cek = prep["cek"] + protected.update(prep["header"]) + else: + # In any other case: + # Keep the normal steps order defined by RFC 7516 + if isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement): + wrapped = alg.wrap(enc, protected, key, sender_key) + else: + wrapped = alg.wrap(enc, protected, key) + cek = wrapped["cek"] + ek = wrapped["ek"] + if "header" in wrapped: + protected.update(wrapped["header"]) # step 4: Generate a random JWE Initialization Vector iv = enc.generate_iv() @@ -94,7 +135,7 @@ def serialize_compact(self, protected, payload, key): # step 5: Let the Additional Authenticated Data encryption parameter # be ASCII(BASE64URL(UTF8(JWE Protected Header))) protected_segment = json_b64encode(protected) - aad = to_bytes(protected_segment, 'ascii') + aad = to_bytes(protected_segment, "ascii") # step 6: compress message if required if zip_alg: @@ -104,43 +145,492 @@ def serialize_compact(self, protected, payload, key): # step 7: perform encryption ciphertext, tag = enc.encrypt(msg, aad, iv, cek) - return b'.'.join([ - protected_segment, - urlsafe_b64encode(ek), - urlsafe_b64encode(iv), - urlsafe_b64encode(ciphertext), - urlsafe_b64encode(tag) - ]) - - def deserialize_compact(self, s, key, decode=None): - """Exact JWS Compact Serialization, and validate with the given key. - - :param s: text of JWS Compact Serialization - :param key: key used to verify the signature - :param decode: a function to decode plaintext data - :return: dict + + if ( + isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement) + and alg.key_size is not None + ): + # For a JWE algorithm with tag-aware key agreement in case key agreement + # with key wrapping mode is used: + # Perform key agreement with key wrapping deferred at step 3 + wrapped = alg.agree_upon_key_and_wrap_cek( + enc, protected, key, sender_key, epk, cek, tag + ) + ek = wrapped["ek"] + + # step 8: build resulting message + return b".".join( + [ + protected_segment, + urlsafe_b64encode(ek), + urlsafe_b64encode(iv), + urlsafe_b64encode(ciphertext), + urlsafe_b64encode(tag), + ] + ) + + def serialize_json(self, header_obj, payload, keys, sender_key=None): # noqa: C901 + """Generate a JWE JSON Serialization (in fully general syntax). + + The JWE JSON Serialization represents encrypted content as a JSON + object. This representation is neither optimized for compactness nor + URL safe. + + The following members are defined for use in top-level JSON objects + used for the fully general JWE JSON Serialization syntax: + + protected + The "protected" member MUST be present and contain the value + BASE64URL(UTF8(JWE Protected Header)) when the JWE Protected + Header value is non-empty; otherwise, it MUST be absent. These + Header Parameter values are integrity protected. + + unprotected + The "unprotected" member MUST be present and contain the value JWE + Shared Unprotected Header when the JWE Shared Unprotected Header + value is non-empty; otherwise, it MUST be absent. This value is + represented as an unencoded JSON object, rather than as a string. + These Header Parameter values are not integrity protected. + + iv + The "iv" member MUST be present and contain the value + BASE64URL(JWE Initialization Vector) when the JWE Initialization + Vector value is non-empty; otherwise, it MUST be absent. + + aad + The "aad" member MUST be present and contain the value + BASE64URL(JWE AAD)) when the JWE AAD value is non-empty; + otherwise, it MUST be absent. A JWE AAD value can be included to + supply a base64url-encoded value to be integrity protected but not + encrypted. + + ciphertext + The "ciphertext" member MUST be present and contain the value + BASE64URL(JWE Ciphertext). + + tag + The "tag" member MUST be present and contain the value + BASE64URL(JWE Authentication Tag) when the JWE Authentication Tag + value is non-empty; otherwise, it MUST be absent. + + recipients + The "recipients" member value MUST be an array of JSON objects. + Each object contains information specific to a single recipient. + This member MUST be present with exactly one array element per + recipient, even if some or all of the array element values are the + empty JSON object "{}" (which can happen when all Header Parameter + values are shared between all recipients and when no encrypted key + is used, such as when doing Direct Encryption). + + The following members are defined for use in the JSON objects that + are elements of the "recipients" array: + + header + The "header" member MUST be present and contain the value JWE Per- + Recipient Unprotected Header when the JWE Per-Recipient + Unprotected Header value is non-empty; otherwise, it MUST be + absent. This value is represented as an unencoded JSON object, + rather than as a string. These Header Parameter values are not + integrity protected. + + encrypted_key + The "encrypted_key" member MUST be present and contain the value + BASE64URL(JWE Encrypted Key) when the JWE Encrypted Key value is + non-empty; otherwise, it MUST be absent. + + This implementation assumes that "alg" and "enc" header fields are + contained in the protected or shared unprotected header. + + :param header_obj: A dict of headers (in addition optionally contains JWE AAD) + :param payload: Payload (bytes or a value convertible to bytes) + :param keys: Public keys (or a single public key) used to encrypt payload + :param sender_key: Sender's private key in case + JWEAlgorithmWithTagAwareKeyAgreement is used + :return: JWE JSON serialization (in fully general syntax) as dict + + Example of `header_obj`:: + + { + "protected": { + "alg": "ECDH-1PU+A128KW", + "enc": "A256CBC-HS512", + "apu": "QWxpY2U", + "apv": "Qm9iIGFuZCBDaGFybGll", + }, + "unprotected": {"jku": "https://alice.example.com/keys.jwks"}, + "recipients": [ + {"header": {"kid": "bob-key-2"}}, + {"header": {"kid": "2021-05-06"}}, + ], + "aad": b"Authenticate me too.", + } + """ + if not isinstance(keys, list): # single key + keys = [keys] + + if not keys: + raise ValueError("No keys have been provided") + + header_obj = deepcopy(header_obj) + + shared_header = JWESharedHeader.from_dict(header_obj) + + recipients = header_obj.get("recipients") + if recipients is None: + recipients = [{} for _ in keys] + for i in range(len(recipients)): + if recipients[i] is None: + recipients[i] = {} + if "header" not in recipients[i]: + recipients[i]["header"] = {} + + jwe_aad = header_obj.get("aad") + + if len(keys) != len(recipients): + raise ValueError( + f"Count of recipient keys {len(keys)} does not equal to count of recipients {len(recipients)}" + ) + + # step 1: Prepare algorithms & key + alg = self.get_header_alg(shared_header) + enc = self.get_header_enc(shared_header) + zip_alg = self.get_header_zip(shared_header) + + self._validate_sender_key(sender_key, alg) + self._validate_private_headers(shared_header, alg) + for recipient in recipients: + self._validate_private_headers(recipient["header"], alg) + + for i in range(len(keys)): + keys[i] = prepare_key(alg, recipients[i]["header"], keys[i]) + if sender_key is not None: + sender_key = alg.prepare_key(sender_key) + + # self._post_validate_header(protected, algorithm) + + # step 2: Generate a random Content Encryption Key (CEK) + # use enc_alg.generate_cek() in scope of upcoming .wrap + # or .generate_keys_and_prepare_headers call + + # step 3: Encrypt the CEK with the recipient's public key + preset = alg.generate_preset(enc, keys[0]) + if "cek" in preset: + cek = preset["cek"] + else: + cek = None + if len(keys) > 1 and cek is None: + raise InvalidAlgorithmForMultipleRecipientsMode(alg.name) + if "header" in preset: + shared_header.update_protected(preset["header"]) + + if ( + isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement) + and alg.key_size is not None + ): + # For a JWE algorithm with tag-aware key agreement in case key agreement + # with key wrapping mode is used: + # Defer key agreement with key wrapping until authentication tag is computed + epks = [] + for i in range(len(keys)): + prep = alg.generate_keys_and_prepare_headers( + enc, keys[i], sender_key, preset + ) + if cek is None: + cek = prep["cek"] + epks.append(prep["epk"]) + recipients[i]["header"].update(prep["header"]) + else: + # In any other case: + # Keep the normal steps order defined by RFC 7516 + for i in range(len(keys)): + if isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement): + wrapped = alg.wrap(enc, shared_header, keys[i], sender_key, preset) + else: + wrapped = alg.wrap(enc, shared_header, keys[i], preset) + if cek is None: + cek = wrapped["cek"] + recipients[i]["encrypted_key"] = wrapped["ek"] + if "header" in wrapped: + recipients[i]["header"].update(wrapped["header"]) + + # step 4: Generate a random JWE Initialization Vector + iv = enc.generate_iv() + + # step 5: Compute the Encoded Protected Header value + # BASE64URL(UTF8(JWE Protected Header)). If the JWE Protected Header + # is not present, let this value be the empty string. + # Let the Additional Authenticated Data encryption parameter be + # ASCII(Encoded Protected Header). However, if a JWE AAD value is + # present, instead let the Additional Authenticated Data encryption + # parameter be ASCII(Encoded Protected Header || '.' || BASE64URL(JWE AAD)). + aad = ( + json_b64encode(shared_header.protected) if shared_header.protected else b"" + ) + if jwe_aad is not None: + aad += b"." + urlsafe_b64encode(jwe_aad) + aad = to_bytes(aad, "ascii") + + # step 6: compress message if required + if zip_alg: + msg = zip_alg.compress(to_bytes(payload)) + else: + msg = to_bytes(payload) + + # step 7: perform encryption + ciphertext, tag = enc.encrypt(msg, aad, iv, cek) + + if ( + isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement) + and alg.key_size is not None + ): + # For a JWE algorithm with tag-aware key agreement in case key agreement + # with key wrapping mode is used: + # Perform key agreement with key wrapping deferred at step 3 + for i in range(len(keys)): + wrapped = alg.agree_upon_key_and_wrap_cek( + enc, shared_header, keys[i], sender_key, epks[i], cek, tag + ) + recipients[i]["encrypted_key"] = wrapped["ek"] + + # step 8: build resulting message + obj = OrderedDict() + + if shared_header.protected: + obj["protected"] = to_unicode(json_b64encode(shared_header.protected)) + + if shared_header.unprotected: + obj["unprotected"] = shared_header.unprotected + + for recipient in recipients: + if not recipient["header"]: + del recipient["header"] + recipient["encrypted_key"] = to_unicode( + urlsafe_b64encode(recipient["encrypted_key"]) + ) + for member in set(recipient.keys()): + if member not in {"header", "encrypted_key"}: + del recipient[member] + obj["recipients"] = recipients + + if jwe_aad is not None: + obj["aad"] = to_unicode(urlsafe_b64encode(jwe_aad)) + + obj["iv"] = to_unicode(urlsafe_b64encode(iv)) + + obj["ciphertext"] = to_unicode(urlsafe_b64encode(ciphertext)) + + obj["tag"] = to_unicode(urlsafe_b64encode(tag)) + + return obj + + def serialize(self, header, payload, key, sender_key=None): + """Generate a JWE Serialization. + + It will automatically generate a compact or JSON serialization depending + on `header` argument. If `header` is a dict with "protected", + "unprotected" and/or "recipients" keys, it will call `serialize_json`, + otherwise it will call `serialize_compact`. + + :param header: A dict of header(s) + :param payload: Payload (bytes or a value convertible to bytes) + :param key: Public key(s) used to encrypt payload + :param sender_key: Sender's private key in case + JWEAlgorithmWithTagAwareKeyAgreement is used + :return: JWE compact serialization as bytes or + JWE JSON serialization as dict + """ + if "protected" in header or "unprotected" in header or "recipients" in header: + return self.serialize_json(header, payload, key, sender_key) + + return self.serialize_compact(header, payload, key, sender_key) + + def deserialize_compact(self, s, key, decode=None, sender_key=None): + """Extract JWE Compact Serialization. + + :param s: JWE Compact Serialization as bytes + :param key: Private key used to decrypt payload + (optionally can be a tuple of kid and essentially key) + :param decode: Function to decode payload data + :param sender_key: Sender's public key in case + JWEAlgorithmWithTagAwareKeyAgreement is used + :return: dict with `header` and `payload` keys where `header` value is + a dict containing protected header fields """ try: s = to_bytes(s) - protected_s, ek_s, iv_s, ciphertext_s, tag_s = s.rsplit(b'.') - except ValueError: - raise DecodeError('Not enough segments') + protected_s, ek_s, iv_s, ciphertext_s, tag_s = s.rsplit(b".") + except ValueError as exc: + raise DecodeError("Not enough segments") from exc protected = extract_header(protected_s, DecodeError) - ek = extract_segment(ek_s, DecodeError, 'encryption key') - iv = extract_segment(iv_s, DecodeError, 'initialization vector') - ciphertext = extract_segment(ciphertext_s, DecodeError, 'ciphertext') - tag = extract_segment(tag_s, DecodeError, 'authentication tag') + ek = extract_segment(ek_s, DecodeError, "encryption key") + iv = extract_segment(iv_s, DecodeError, "initialization vector") + ciphertext = extract_segment(ciphertext_s, DecodeError, "ciphertext") + tag = extract_segment(tag_s, DecodeError, "authentication tag") alg = self.get_header_alg(protected) enc = self.get_header_enc(protected) zip_alg = self.get_header_zip(protected) + + self._validate_sender_key(sender_key, alg) self._validate_private_headers(protected, alg) + if isinstance(key, tuple) and len(key) == 2: + # Ignore separately provided kid, extract essentially key only + key = key[1] + key = prepare_key(alg, protected, key) - cek = alg.unwrap(enc, ek, protected, key) - aad = to_bytes(protected_s, 'ascii') + if sender_key is not None: + sender_key = alg.prepare_key(sender_key) + + if isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement): + # For a JWE algorithm with tag-aware key agreement: + if alg.key_size is not None: + # In case key agreement with key wrapping mode is used: + # Provide authentication tag to .unwrap method + cek = alg.unwrap(enc, ek, protected, key, sender_key, tag) + else: + # Otherwise, don't provide authentication tag to .unwrap method + cek = alg.unwrap(enc, ek, protected, key, sender_key) + else: + # For any other JWE algorithm: + # Don't provide authentication tag to .unwrap method + cek = alg.unwrap(enc, ek, protected, key) + + aad = to_bytes(protected_s, "ascii") + msg = enc.decrypt(ciphertext, aad, iv, tag, cek) + + if zip_alg: + payload = zip_alg.decompress(to_bytes(msg)) + else: + payload = msg + + if decode: + payload = decode(payload) + return {"header": protected, "payload": payload} + + def deserialize_json(self, obj, key, decode=None, sender_key=None): # noqa: C901 + """Extract JWE JSON Serialization. + + :param obj: JWE JSON Serialization as dict or str + :param key: Private key used to decrypt payload + (optionally can be a tuple of kid and essentially key) + :param decode: Function to decode payload data + :param sender_key: Sender's public key in case + JWEAlgorithmWithTagAwareKeyAgreement is used + :return: dict with `header` and `payload` keys where `header` value is + a dict containing `protected`, `unprotected`, `recipients` and/or + `aad` keys + """ + obj = ensure_dict(obj, "JWE") + obj = deepcopy(obj) + + if "protected" in obj: + protected = extract_header(to_bytes(obj["protected"]), DecodeError) + else: + protected = None + + unprotected = obj.get("unprotected") + + recipients = obj["recipients"] + for recipient in recipients: + if "header" not in recipient: + recipient["header"] = {} + recipient["encrypted_key"] = extract_segment( + to_bytes(recipient["encrypted_key"]), DecodeError, "encrypted key" + ) + + if "aad" in obj: + jwe_aad = extract_segment(to_bytes(obj["aad"]), DecodeError, "JWE AAD") + else: + jwe_aad = None + + iv = extract_segment(to_bytes(obj["iv"]), DecodeError, "initialization vector") + + ciphertext = extract_segment( + to_bytes(obj["ciphertext"]), DecodeError, "ciphertext" + ) + + tag = extract_segment(to_bytes(obj["tag"]), DecodeError, "authentication tag") + + shared_header = JWESharedHeader(protected, unprotected) + + alg = self.get_header_alg(shared_header) + enc = self.get_header_enc(shared_header) + zip_alg = self.get_header_zip(shared_header) + + self._validate_sender_key(sender_key, alg) + self._validate_private_headers(shared_header, alg) + for recipient in recipients: + self._validate_private_headers(recipient["header"], alg) + + kid = None + if isinstance(key, tuple) and len(key) == 2: + # Extract separately provided kid and essentially key + kid = key[0] + key = key[1] + + key = alg.prepare_key(key) + + if kid is None: + # If kid has not been provided separately, try to get it from key itself + kid = key.kid + + if sender_key is not None: + sender_key = alg.prepare_key(sender_key) + + def _unwrap_with_sender_key_and_tag(ek, header): + return alg.unwrap(enc, ek, header, key, sender_key, tag) + + def _unwrap_with_sender_key_and_without_tag(ek, header): + return alg.unwrap(enc, ek, header, key, sender_key) + + def _unwrap_without_sender_key_and_tag(ek, header): + return alg.unwrap(enc, ek, header, key) + + def _unwrap_for_matching_recipient(unwrap_func): + if kid is not None: + for recipient in recipients: + if recipient["header"].get("kid") == kid: + header = JWEHeader(protected, unprotected, recipient["header"]) + return unwrap_func(recipient["encrypted_key"], header) + + # Since no explicit match has been found, iterate over all the recipients + error = None + for recipient in recipients: + header = JWEHeader(protected, unprotected, recipient["header"]) + try: + return unwrap_func(recipient["encrypted_key"], header) + except Exception as e: + error = e + else: + if error is None: + raise KeyMismatchError() + else: + raise error + + if isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement): + # For a JWE algorithm with tag-aware key agreement: + if alg.key_size is not None: + # In case key agreement with key wrapping mode is used: + # Provide authentication tag to .unwrap method + cek = _unwrap_for_matching_recipient(_unwrap_with_sender_key_and_tag) + else: + # Otherwise, don't provide authentication tag to .unwrap method + cek = _unwrap_for_matching_recipient( + _unwrap_with_sender_key_and_without_tag + ) + else: + # For any other JWE algorithm: + # Don't provide authentication tag to .unwrap method + cek = _unwrap_for_matching_recipient(_unwrap_without_sender_key_and_tag) + + aad = to_bytes(obj.get("protected", "")) + if "aad" in obj: + aad += b"." + to_bytes(obj["aad"]) + aad = to_bytes(aad, "ascii") + msg = enc.decrypt(ciphertext, aad, iv, tag, cek) if zip_alg: @@ -150,38 +640,108 @@ def deserialize_compact(self, s, key, decode=None): if decode: payload = decode(payload) - return {'header': protected, 'payload': payload} + + for recipient in recipients: + if not recipient["header"]: + del recipient["header"] + for member in set(recipient.keys()): + if member != "header": + del recipient[member] + + header = {} + if protected: + header["protected"] = protected + if unprotected: + header["unprotected"] = unprotected + header["recipients"] = recipients + if jwe_aad is not None: + header["aad"] = jwe_aad + + return {"header": header, "payload": payload} + + def deserialize(self, obj, key, decode=None, sender_key=None): + """Extract a JWE Serialization. + + It supports both compact and JSON serialization. + + :param obj: JWE compact serialization as bytes or + JWE JSON serialization as dict or str + :param key: Private key used to decrypt payload + (optionally can be a tuple of kid and essentially key) + :param decode: Function to decode payload data + :param sender_key: Sender's public key in case + JWEAlgorithmWithTagAwareKeyAgreement is used + :return: dict with `header` and `payload` keys + """ + if isinstance(obj, dict): + return self.deserialize_json(obj, key, decode, sender_key) + + obj = to_bytes(obj) + if obj.startswith(b"{") and obj.endswith(b"}"): + return self.deserialize_json(obj, key, decode, sender_key) + + return self.deserialize_compact(obj, key, decode, sender_key) + + @staticmethod + def parse_json(obj): + """Parse JWE JSON Serialization. + + :param obj: JWE JSON Serialization as str or dict + :return: Parsed JWE JSON Serialization as dict if `obj` is an str, + or `obj` as is if `obj` is already a dict + """ + return ensure_dict(obj, "JWE") def get_header_alg(self, header): - if 'alg' not in header: + if "alg" not in header: raise MissingAlgorithmError() - alg = header['alg'] - if self._algorithms and alg not in self._algorithms: - raise UnsupportedAlgorithmError() + alg = header["alg"] if alg not in self.ALG_REGISTRY: raise UnsupportedAlgorithmError() - return self.ALG_REGISTRY[alg] + + instance = self.ALG_REGISTRY[alg] + + # use all ALG_REGISTRY algorithms + if self._algorithms is None: + # do not use deprecated algorithms + if instance.deprecated: + raise UnsupportedAlgorithmError() + elif alg not in self._algorithms: + raise UnsupportedAlgorithmError() + return instance def get_header_enc(self, header): - if 'enc' not in header: + if "enc" not in header: raise MissingEncryptionAlgorithmError() - enc = header['enc'] - if self._algorithms and enc not in self._algorithms: + enc = header["enc"] + if self._algorithms is not None and enc not in self._algorithms: raise UnsupportedEncryptionAlgorithmError() if enc not in self.ENC_REGISTRY: raise UnsupportedEncryptionAlgorithmError() return self.ENC_REGISTRY[enc] def get_header_zip(self, header): - if 'zip' in header: - z = header['zip'] - if self._algorithms and z not in self._algorithms: + if "zip" in header: + z = header["zip"] + if self._algorithms is not None and z not in self._algorithms: raise UnsupportedCompressionAlgorithmError() if z not in self.ZIP_REGISTRY: raise UnsupportedCompressionAlgorithmError() return self.ZIP_REGISTRY[z] + def _validate_sender_key(self, sender_key, alg): + if isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement): + if sender_key is None: + raise ValueError( + f"{alg.name} algorithm requires sender_key but passed sender_key value is None" + ) + else: + if sender_key is not None: + raise ValueError( + f"{alg.name} algorithm does not use sender_key but passed sender_key value is not None" + ) + def _validate_private_headers(self, header, alg): # only validate private headers when developers set # private headers explicitly @@ -196,12 +756,10 @@ def _validate_private_headers(self, header, alg): for k in header: if k not in names: - raise InvalidHeaderParameterName(k) + raise InvalidHeaderParameterNameError(k) def prepare_key(alg, header, key): if callable(key): key = key(header, None) - elif 'jwk' in header: - key = header['jwk'] return alg.prepare_key(key) diff --git a/authlib/jose/rfc7516/models.py b/authlib/jose/rfc7516/models.py index 5eab89c7a..ce98257f3 100644 --- a/authlib/jose/rfc7516/models.py +++ b/authlib/jose/rfc7516/models.py @@ -1,32 +1,64 @@ import os +from abc import ABCMeta -class JWEAlgorithm(object): - """Interface for JWE algorithm. JWA specification (RFC7518) SHOULD - implement the algorithms for JWE with this base implementation. - """ +class JWEAlgorithmBase(metaclass=ABCMeta): # noqa: B024 + """Base interface for all JWE algorithms.""" + EXTRA_HEADERS = None name = None description = None - algorithm_type = 'JWE' - algorithm_location = 'alg' + deprecated = False + algorithm_type = "JWE" + algorithm_location = "alg" def prepare_key(self, raw_data): raise NotImplementedError - def wrap(self, enc_alg, headers, key): + def generate_preset(self, enc_alg, key): + raise NotImplementedError + + +class JWEAlgorithm(JWEAlgorithmBase, metaclass=ABCMeta): + """Interface for JWE algorithm conforming to RFC7518. + JWA specification (RFC7518) SHOULD implement the algorithms for JWE + with this base implementation. + """ + + def wrap(self, enc_alg, headers, key, preset=None): raise NotImplementedError def unwrap(self, enc_alg, ek, headers, key): raise NotImplementedError -class JWEEncAlgorithm(object): +class JWEAlgorithmWithTagAwareKeyAgreement(JWEAlgorithmBase, metaclass=ABCMeta): + """Interface for JWE algorithm with tag-aware key agreement (in key agreement + with key wrapping mode). + ECDH-1PU is an example of such an algorithm. + """ + + def generate_keys_and_prepare_headers(self, enc_alg, key, sender_key, preset=None): + raise NotImplementedError + + def agree_upon_key_and_wrap_cek( + self, enc_alg, headers, key, sender_key, epk, cek, tag + ): + raise NotImplementedError + + def wrap(self, enc_alg, headers, key, sender_key, preset=None): + raise NotImplementedError + + def unwrap(self, enc_alg, ek, headers, key, sender_key, tag=None): + raise NotImplementedError + + +class JWEEncAlgorithm: name = None description = None - algorithm_type = 'JWE' - algorithm_location = 'enc' + algorithm_type = "JWE" + algorithm_location = "enc" IV_SIZE = None CEK_SIZE = None @@ -48,7 +80,7 @@ def encrypt(self, msg, aad, iv, key): :param aad: additional authenticated data in bytes :param iv: initialization vector in bytes :param key: encrypted key in bytes - :return: (ciphertext, iv, tag) + :return: (ciphertext, tag) """ raise NotImplementedError @@ -65,14 +97,62 @@ def decrypt(self, ciphertext, aad, iv, tag, key): raise NotImplementedError -class JWEZipAlgorithm(object): +class JWEZipAlgorithm: name = None description = None - algorithm_type = 'JWE' - algorithm_location = 'zip' + algorithm_type = "JWE" + algorithm_location = "zip" def compress(self, s): raise NotImplementedError def decompress(self, s): raise NotImplementedError + + +class JWESharedHeader(dict): + """Shared header object for JWE. + + Combines protected header and shared unprotected header together. + """ + + def __init__(self, protected, unprotected): + obj = {} + if unprotected: + obj.update(unprotected) + if protected: + obj.update(protected) + super().__init__(obj) + self.protected = protected if protected else {} + self.unprotected = unprotected if unprotected else {} + + def update_protected(self, addition): + self.update(addition) + self.protected.update(addition) + + @classmethod + def from_dict(cls, obj): + if isinstance(obj, cls): + return obj + return cls(obj.get("protected"), obj.get("unprotected")) + + +class JWEHeader(dict): + """Header object for JWE. + + Combines protected header, shared unprotected header + and specific recipient's unprotected header together. + """ + + def __init__(self, protected, unprotected, header): + obj = {} + if unprotected: + obj.update(unprotected) + if header: + obj.update(header) + if protected: + obj.update(protected) + super().__init__(obj) + self.protected = protected if protected else {} + self.unprotected = unprotected if unprotected else {} + self.header = header if header else {} diff --git a/authlib/jose/rfc7517/__init__.py b/authlib/jose/rfc7517/__init__.py index 079a7ccc4..2f41e3b54 100644 --- a/authlib/jose/rfc7517/__init__.py +++ b/authlib/jose/rfc7517/__init__.py @@ -1,13 +1,16 @@ -""" - authlib.jose.rfc7517 - ~~~~~~~~~~~~~~~~~~~~~ +"""authlib.jose.rfc7517. +~~~~~~~~~~~~~~~~~~~~~ - This module represents a direct implementation of - JSON Web Key (JWK). +This module represents a direct implementation of +JSON Web Key (JWK). - https://tools.ietf.org/html/rfc7517 +https://tools.ietf.org/html/rfc7517 """ -from .models import Key, KeySet +from ._cryptography_key import load_pem_key +from .asymmetric_key import AsymmetricKey +from .base_key import Key +from .jwk import JsonWebKey +from .key_set import KeySet -__all__ = ['Key', 'KeySet'] +__all__ = ["Key", "AsymmetricKey", "KeySet", "JsonWebKey", "load_pem_key"] diff --git a/authlib/jose/rfc7517/_cryptography_key.py b/authlib/jose/rfc7517/_cryptography_key.py new file mode 100644 index 000000000..ad16e9e53 --- /dev/null +++ b/authlib/jose/rfc7517/_cryptography_key.py @@ -0,0 +1,35 @@ +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.serialization import load_pem_private_key +from cryptography.hazmat.primitives.serialization import load_pem_public_key +from cryptography.hazmat.primitives.serialization import load_ssh_public_key +from cryptography.x509 import load_pem_x509_certificate + +from authlib.common.encoding import to_bytes + + +def load_pem_key(raw, ssh_type=None, key_type=None, password=None): + raw = to_bytes(raw) + + if ssh_type and raw.startswith(ssh_type): + return load_ssh_public_key(raw, backend=default_backend()) + + if key_type == "public": + return load_pem_public_key(raw, backend=default_backend()) + + if key_type == "private" or password is not None: + return load_pem_private_key(raw, password=password, backend=default_backend()) + + if b"PUBLIC" in raw: + return load_pem_public_key(raw, backend=default_backend()) + + if b"PRIVATE" in raw: + return load_pem_private_key(raw, password=password, backend=default_backend()) + + if b"CERTIFICATE" in raw: + cert = load_pem_x509_certificate(raw, default_backend()) + return cert.public_key() + + try: + return load_pem_private_key(raw, password=password, backend=default_backend()) + except ValueError: + return load_pem_public_key(raw, backend=default_backend()) diff --git a/authlib/jose/rfc7517/asymmetric_key.py b/authlib/jose/rfc7517/asymmetric_key.py new file mode 100644 index 000000000..571c851eb --- /dev/null +++ b/authlib/jose/rfc7517/asymmetric_key.py @@ -0,0 +1,196 @@ +from cryptography.hazmat.primitives.serialization import BestAvailableEncryption +from cryptography.hazmat.primitives.serialization import Encoding +from cryptography.hazmat.primitives.serialization import NoEncryption +from cryptography.hazmat.primitives.serialization import PrivateFormat +from cryptography.hazmat.primitives.serialization import PublicFormat + +from authlib.common.encoding import to_bytes + +from ._cryptography_key import load_pem_key +from .base_key import Key + + +class AsymmetricKey(Key): + """This is the base class for a JSON Web Key.""" + + PUBLIC_KEY_FIELDS = [] + PRIVATE_KEY_FIELDS = [] + PRIVATE_KEY_CLS = bytes + PUBLIC_KEY_CLS = bytes + SSH_PUBLIC_PREFIX = b"" + + def __init__(self, private_key=None, public_key=None, options=None): + super().__init__(options) + self.private_key = private_key + self.public_key = public_key + + @property + def public_only(self): + if self.private_key: + return False + if "d" in self.tokens: + return False + return True + + def get_op_key(self, operation): + """Get the raw key for the given key_op. This method will also + check if the given key_op is supported by this key. + + :param operation: key operation value, such as "sign", "encrypt". + :return: raw key + """ + self.check_key_op(operation) + if operation in self.PUBLIC_KEY_OPS: + return self.get_public_key() + return self.get_private_key() + + def get_public_key(self): + if self.public_key: + return self.public_key + + private_key = self.get_private_key() + if private_key: + return private_key.public_key() + + return self.public_key + + def get_private_key(self): + if self.private_key: + return self.private_key + + if self.tokens: + self.load_raw_key() + return self.private_key + + def load_raw_key(self): + if "d" in self.tokens: + self.private_key = self.load_private_key() + else: + self.public_key = self.load_public_key() + + def load_dict_key(self): + if self.private_key: + self._dict_data.update(self.dumps_private_key()) + else: + self._dict_data.update(self.dumps_public_key()) + + def dumps_private_key(self): + raise NotImplementedError() + + def dumps_public_key(self): + raise NotImplementedError() + + def load_private_key(self): + raise NotImplementedError() + + def load_public_key(self): + raise NotImplementedError() + + def as_dict(self, is_private=False, **params): + """Represent this key as a dict of the JSON Web Key.""" + tokens = self.tokens + if is_private and "d" not in tokens: + raise ValueError("This is a public key") + + kid = tokens.get("kid") + if "d" in tokens and not is_private: + # filter out private fields + tokens = {k: tokens[k] for k in tokens if k in self.PUBLIC_KEY_FIELDS} + tokens["kty"] = self.kty + if kid: + tokens["kid"] = kid + + if not kid: + tokens["kid"] = self.thumbprint() + + tokens.update(params) + return tokens + + def as_key(self, is_private=False): + """Represent this key as raw key.""" + if is_private: + return self.get_private_key() + return self.get_public_key() + + def as_bytes(self, encoding=None, is_private=False, password=None): + """Export key into PEM/DER format bytes. + + :param encoding: "PEM" or "DER" + :param is_private: export private key or public key + :param password: encrypt private key with password + :return: bytes + """ + if encoding is None or encoding == "PEM": + encoding = Encoding.PEM + elif encoding == "DER": + encoding = Encoding.DER + else: + raise ValueError(f"Invalid encoding: {encoding!r}") + + raw_key = self.as_key(is_private) + if is_private: + if not raw_key: + raise ValueError("This is a public key") + if password is None: + encryption_algorithm = NoEncryption() + else: + encryption_algorithm = BestAvailableEncryption(to_bytes(password)) + return raw_key.private_bytes( + encoding=encoding, + format=PrivateFormat.PKCS8, + encryption_algorithm=encryption_algorithm, + ) + return raw_key.public_bytes( + encoding=encoding, + format=PublicFormat.SubjectPublicKeyInfo, + ) + + def as_pem(self, is_private=False, password=None): + return self.as_bytes(is_private=is_private, password=password) + + def as_der(self, is_private=False, password=None): + return self.as_bytes(encoding="DER", is_private=is_private, password=password) + + @classmethod + def import_dict_key(cls, raw, options=None): + cls.check_required_fields(raw) + key = cls(options=options) + key._dict_data = raw + return key + + @classmethod + def import_key(cls, raw, options=None): + if isinstance(raw, cls): + if options is not None: + raw.options.update(options) + return raw + + if isinstance(raw, cls.PUBLIC_KEY_CLS): + key = cls(public_key=raw, options=options) + elif isinstance(raw, cls.PRIVATE_KEY_CLS): + key = cls(private_key=raw, options=options) + elif isinstance(raw, dict): + key = cls.import_dict_key(raw, options) + else: + if options is not None: + password = options.pop("password", None) + else: + password = None + raw_key = load_pem_key(raw, cls.SSH_PUBLIC_PREFIX, password=password) + if isinstance(raw_key, cls.PUBLIC_KEY_CLS): + key = cls(public_key=raw_key, options=options) + elif isinstance(raw_key, cls.PRIVATE_KEY_CLS): + key = cls(private_key=raw_key, options=options) + else: + raise ValueError("Invalid data for importing key") + return key + + @classmethod + def validate_raw_key(cls, key): + return isinstance(key, cls.PUBLIC_KEY_CLS) or isinstance( + key, cls.PRIVATE_KEY_CLS + ) + + @classmethod + def generate_key(cls, crv_or_size, options=None, is_private=False): + raise NotImplementedError() diff --git a/authlib/jose/rfc7517/base_key.py b/authlib/jose/rfc7517/base_key.py new file mode 100644 index 000000000..0baa62c60 --- /dev/null +++ b/authlib/jose/rfc7517/base_key.py @@ -0,0 +1,120 @@ +import hashlib +from collections import OrderedDict + +from authlib.common.encoding import json_dumps +from authlib.common.encoding import to_bytes +from authlib.common.encoding import to_unicode +from authlib.common.encoding import urlsafe_b64encode + +from ..errors import InvalidUseError + + +class Key: + """This is the base class for a JSON Web Key.""" + + kty = "_" + + ALLOWED_PARAMS = ["use", "key_ops", "alg", "kid", "x5u", "x5c", "x5t", "x5t#S256"] + + PRIVATE_KEY_OPS = [ + "sign", + "decrypt", + "unwrapKey", + ] + PUBLIC_KEY_OPS = [ + "verify", + "encrypt", + "wrapKey", + ] + + REQUIRED_JSON_FIELDS = [] + + def __init__(self, options=None): + self.options = options or {} + self._dict_data = {} + + @property + def tokens(self): + if not self._dict_data: + self.load_dict_key() + + rv = dict(self._dict_data) + rv["kty"] = self.kty + for k in self.ALLOWED_PARAMS: + if k not in rv and k in self.options: + rv[k] = self.options[k] + return rv + + @property + def kid(self): + return self.tokens.get("kid") + + def keys(self): + return self.tokens.keys() + + def __getitem__(self, item): + return self.tokens[item] + + @property + def public_only(self): + raise NotImplementedError() + + def load_raw_key(self): + raise NotImplementedError() + + def load_dict_key(self): + raise NotImplementedError() + + def check_key_op(self, operation): + """Check if the given key_op is supported by this key. + + :param operation: key operation value, such as "sign", "encrypt". + :raise: ValueError + """ + key_ops = self.tokens.get("key_ops") + if key_ops is not None and operation not in key_ops: + raise ValueError(f'Unsupported key_op "{operation}"') + + if operation in self.PRIVATE_KEY_OPS and self.public_only: + raise ValueError(f'Invalid key_op "{operation}" for public key') + + use = self.tokens.get("use") + if use: + if operation in ["sign", "verify"]: + if use != "sig": + raise InvalidUseError() + elif operation in ["decrypt", "encrypt", "wrapKey", "unwrapKey"]: + if use != "enc": + raise InvalidUseError() + + def as_dict(self, is_private=False, **params): + raise NotImplementedError() + + def as_json(self, is_private=False, **params): + """Represent this key as a JSON string.""" + obj = self.as_dict(is_private, **params) + return json_dumps(obj) + + def thumbprint(self): + """Implementation of RFC7638 JSON Web Key (JWK) Thumbprint.""" + fields = list(self.REQUIRED_JSON_FIELDS) + fields.append("kty") + fields.sort() + data = OrderedDict() + + for k in fields: + data[k] = self.tokens[k] + + json_data = json_dumps(data) + digest_data = hashlib.sha256(to_bytes(json_data)).digest() + return to_unicode(urlsafe_b64encode(digest_data)) + + @classmethod + def check_required_fields(cls, data): + for k in cls.REQUIRED_JSON_FIELDS: + if k not in data: + raise ValueError(f'Missing required field: "{k}"') + + @classmethod + def validate_raw_key(cls, key): + raise NotImplementedError() diff --git a/authlib/jose/rfc7517/jwk.py b/authlib/jose/rfc7517/jwk.py new file mode 100644 index 000000000..034691d2a --- /dev/null +++ b/authlib/jose/rfc7517/jwk.py @@ -0,0 +1,64 @@ +from authlib.common.encoding import json_loads + +from ._cryptography_key import load_pem_key +from .key_set import KeySet + + +class JsonWebKey: + JWK_KEY_CLS = {} + + @classmethod + def generate_key(cls, kty, crv_or_size, options=None, is_private=False): + """Generate a Key with the given key type, curve name or bit size. + + :param kty: string of ``oct``, ``RSA``, ``EC``, ``OKP`` + :param crv_or_size: curve name or bit size + :param options: a dict of other options for Key + :param is_private: create a private key or public key + :return: Key instance + """ + key_cls = cls.JWK_KEY_CLS[kty] + return key_cls.generate_key(crv_or_size, options, is_private) + + @classmethod + def import_key(cls, raw, options=None): + """Import a Key from bytes, string, PEM or dict. + + :return: Key instance + """ + kty = None + if options is not None: + kty = options.get("kty") + + if kty is None and isinstance(raw, dict): + kty = raw.get("kty") + + if kty is None: + raw_key = load_pem_key(raw) + for _kty in cls.JWK_KEY_CLS: + key_cls = cls.JWK_KEY_CLS[_kty] + if key_cls.validate_raw_key(raw_key): + return key_cls.import_key(raw_key, options) + + key_cls = cls.JWK_KEY_CLS[kty] + return key_cls.import_key(raw, options) + + @classmethod + def import_key_set(cls, raw): + """Import KeySet from string, dict or a list of keys. + + :return: KeySet instance + """ + raw = _transform_raw_key(raw) + if isinstance(raw, dict) and "keys" in raw: + keys = raw.get("keys") + return KeySet([cls.import_key(k) for k in keys]) + raise ValueError("Invalid key set format") + + +def _transform_raw_key(raw): + if isinstance(raw, str) and raw.startswith("{") and raw.endswith("}"): + return json_loads(raw) + elif isinstance(raw, (tuple, list)): + return {"keys": raw} + return raw diff --git a/authlib/jose/rfc7517/key_set.py b/authlib/jose/rfc7517/key_set.py new file mode 100644 index 000000000..bd8fa691c --- /dev/null +++ b/authlib/jose/rfc7517/key_set.py @@ -0,0 +1,53 @@ +from authlib.common.encoding import json_dumps + + +class KeySet: + """This class represents a JSON Web Key Set.""" + + def __init__(self, keys): + self.keys = keys + + def as_dict(self, is_private=False, **params): + """Represent this key as a dict of the JSON Web Key Set.""" + return {"keys": [k.as_dict(is_private, **params) for k in self.keys]} + + def as_json(self, is_private=False, **params): + """Represent this key set as a JSON string.""" + obj = self.as_dict(is_private, **params) + return json_dumps(obj) + + def find_by_kid(self, kid, **params): + """Find the key matches the given kid value. + + :param kid: A string of kid + :return: Key instance + :raise: ValueError + """ + # Proposed fix, feel free to do something else but the idea is that we take the only key + # of the set if no kid is specified + if kid is None and len(self.keys) == 1: + return self.keys[0] + + keys = [key for key in self.keys if key.kid == kid] + if params: + keys = list(_filter_keys_by_params(keys, **params)) + + if keys: + return keys[0] + raise ValueError("Key not found") + + +def _filter_keys_by_params(keys, **params): + _use = params.get("use") + _alg = params.get("alg") + + for key in keys: + designed_use = key.tokens.get("use") + if designed_use and _use and designed_use != _use: + continue + + designed_alg = key.tokens.get("alg") + if designed_alg and _alg and designed_alg != _alg: + continue + + yield key diff --git a/authlib/jose/rfc7517/models.py b/authlib/jose/rfc7517/models.py deleted file mode 100644 index b3b24f321..000000000 --- a/authlib/jose/rfc7517/models.py +++ /dev/null @@ -1,156 +0,0 @@ -import hashlib -from collections import OrderedDict -from authlib.common.encoding import ( - json_dumps, - to_bytes, - to_unicode, - urlsafe_b64encode, -) -from ..errors import InvalidUseError - - -class Key(dict): - """This is the base class for a JSON Web Key.""" - kty = '_' - - ALLOWED_PARAMS = [ - 'use', 'key_ops', 'alg', 'kid', - 'x5u', 'x5c', 'x5t', 'x5t#S256' - ] - - PRIVATE_KEY_OPS = [ - 'sign', 'decrypt', 'unwrapKey', - ] - PUBLIC_KEY_OPS = [ - 'verify', 'encrypt', 'wrapKey', - ] - - REQUIRED_JSON_FIELDS = [] - RAW_KEY_CLS = bytes - - def __init__(self, payload): - super(Key, self).__init__(payload) - - self.key_type = 'secret' - self.raw_key = None - - def get_op_key(self, operation): - """Get the raw key for the given key_op. This method will also - check if the given key_op is supported by this key. - - :param operation: key operation value, such as "sign", "encrypt". - :return: raw key - """ - self.check_key_op(operation) - if operation in self.PUBLIC_KEY_OPS: - return self.get_public_key() - return self.get_private_key() - - def get_public_key(self): - if self.key_type == 'private': - return self.raw_key.public_key() - return self.raw_key - - def get_private_key(self): - if self.key_type == 'private': - return self.raw_key - - def check_key_op(self, operation): - """Check if the given key_op is supported by this key. - - :param operation: key operation value, such as "sign", "encrypt". - :raise: ValueError - """ - key_ops = self.get('key_ops') - if key_ops is not None and operation not in key_ops: - raise ValueError('Unsupported key_op "{}"'.format(operation)) - - if operation in self.PRIVATE_KEY_OPS and self.key_type == 'public': - raise ValueError('Invalid key_op "{}" for public key'.format(operation)) - - use = self.get('use') - if use: - if operation in ['sign', 'verify']: - if use != 'sig': - raise InvalidUseError() - elif operation in ['decrypt', 'encrypt', 'wrapKey', 'unwrapKey']: - if use != 'enc': - raise InvalidUseError() - - def as_key(self): - """Represent this key as raw key.""" - return self.raw_key - - def as_dict(self, add_kid=False): - """Represent this key as a dict of the JSON Web Key.""" - obj = dict(self) - obj['kty'] = self.kty - if add_kid and 'kid' not in obj: - obj['kid'] = self.thumbprint() - return obj - - def as_json(self): - """Represent this key as a JSON string.""" - obj = self.as_dict() - return json_dumps(obj) - - def as_pem(self): - """Represent this key as string in PEM format.""" - raise RuntimeError('Not supported') - - def thumbprint(self): - """Implementation of RFC7638 JSON Web Key (JWK) Thumbprint.""" - fields = list(self.REQUIRED_JSON_FIELDS) - fields.append('kty') - fields.sort() - data = OrderedDict() - - obj = self.as_dict() - for k in fields: - data[k] = obj[k] - - json_data = json_dumps(data) - digest_data = hashlib.sha256(to_bytes(json_data)).digest() - return to_unicode(urlsafe_b64encode(digest_data)) - - @classmethod - def check_required_fields(cls, data): - for k in cls.REQUIRED_JSON_FIELDS: - if k not in data: - raise ValueError('Missing required field: "{}"'.format(k)) - - @classmethod - def generate_key(cls, crv_or_size, options=None, is_private=False): - raise NotImplementedError() - - @classmethod - def import_key(cls, raw, options=None): - raise NotImplementedError() - - -class KeySet(object): - """This class represents a JSON Web Key Set.""" - - def __init__(self, keys): - self.keys = keys - - def as_dict(self): - """Represent this key as a dict of the JSON Web Key Set.""" - return {'keys': [k.as_dict(True) for k in self.keys]} - - def as_json(self): - """Represent this key set as a JSON string.""" - obj = self.as_dict() - return json_dumps(obj) - - def find_by_kid(self, kid): - """Find the key matches the given kid value. - - :param kid: A string of kid - :return: Key instance - :raise: ValueError - """ - for k in self.keys: - if k.get('kid') == kid: - return k - raise ValueError('Invalid JSON Web Key Set') diff --git a/authlib/jose/rfc7518/__init__.py b/authlib/jose/rfc7518/__init__.py index 658760245..9b9dbcb79 100644 --- a/authlib/jose/rfc7518/__init__.py +++ b/authlib/jose/rfc7518/__init__.py @@ -1,19 +1,39 @@ -from .jws_algorithms import register_jws_rfc7518 -from .jwe_algorithms import register_jwe_rfc7518 +from .ec_key import ECKey +from .jwe_algs import JWE_ALG_ALGORITHMS +from .jwe_algs import AESAlgorithm +from .jwe_algs import ECDHESAlgorithm +from .jwe_algs import u32be_len_input +from .jwe_encs import JWE_ENC_ALGORITHMS +from .jwe_encs import CBCHS2EncAlgorithm +from .jwe_zips import DeflateZipAlgorithm +from .jws_algs import JWS_ALGORITHMS from .oct_key import OctKey -from ._cryptography_backends import ( - RSAKey, ECKey, ECDHAlgorithm, - import_key, load_pem_key, export_key, -) +from .rsa_key import RSAKey + + +def register_jws_rfc7518(cls): + for algorithm in JWS_ALGORITHMS: + cls.register_algorithm(algorithm) + + +def register_jwe_rfc7518(cls): + for algorithm in JWE_ALG_ALGORITHMS: + cls.register_algorithm(algorithm) + + for algorithm in JWE_ENC_ALGORITHMS: + cls.register_algorithm(algorithm) + + cls.register_algorithm(DeflateZipAlgorithm()) + __all__ = [ - 'register_jws_rfc7518', - 'register_jwe_rfc7518', - 'ECDHAlgorithm', - 'OctKey', - 'RSAKey', - 'ECKey', - 'import_key', - 'load_pem_key', - 'export_key', + "register_jws_rfc7518", + "register_jwe_rfc7518", + "OctKey", + "RSAKey", + "ECKey", + "u32be_len_input", + "AESAlgorithm", + "ECDHESAlgorithm", + "CBCHS2EncAlgorithm", ] diff --git a/authlib/jose/rfc7518/_cryptography_backends/__init__.py b/authlib/jose/rfc7518/_cryptography_backends/__init__.py deleted file mode 100644 index 5f8ab16db..000000000 --- a/authlib/jose/rfc7518/_cryptography_backends/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from ._jws import JWS_ALGORITHMS -from ._jwe_alg import JWE_ALG_ALGORITHMS, ECDHAlgorithm -from ._jwe_enc import JWE_ENC_ALGORITHMS -from ._keys import ( - RSAKey, ECKey, - load_pem_key, import_key, export_key, -) diff --git a/authlib/jose/rfc7518/_cryptography_backends/_jwe_alg.py b/authlib/jose/rfc7518/_cryptography_backends/_jwe_alg.py deleted file mode 100644 index 8d000d216..000000000 --- a/authlib/jose/rfc7518/_cryptography_backends/_jwe_alg.py +++ /dev/null @@ -1,268 +0,0 @@ -import os -import struct -from cryptography.hazmat.primitives.asymmetric import padding -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives.keywrap import ( - aes_key_wrap, - aes_key_unwrap -) -from cryptography.hazmat.primitives.ciphers import Cipher -from cryptography.hazmat.primitives.ciphers.algorithms import AES -from cryptography.hazmat.primitives.ciphers.modes import GCM -from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash -from authlib.common.encoding import ( - to_bytes, to_native, - urlsafe_b64decode, - urlsafe_b64encode -) -from authlib.jose.rfc7516 import JWEAlgorithm -from ._keys import RSAKey, ECKey -from ..oct_key import OctKey - - -class RSAAlgorithm(JWEAlgorithm): - #: A key of size 2048 bits or larger MUST be used with these algorithms - #: RSA1_5, RSA-OAEP, RSA-OAEP-256 - key_size = 2048 - - def __init__(self, name, description, pad_fn): - self.name = name - self.description = description - self.padding = pad_fn - - def prepare_key(self, raw_data): - return RSAKey.import_key(raw_data) - - def wrap(self, enc_alg, headers, key): - cek = enc_alg.generate_cek() - op_key = key.get_op_key('wrapKey') - if op_key.key_size < self.key_size: - raise ValueError('A key of size 2048 bits or larger MUST be used') - ek = op_key.encrypt(cek, self.padding) - return {'ek': ek, 'cek': cek} - - def unwrap(self, enc_alg, ek, headers, key): - # it will raise ValueError if failed - op_key = key.get_op_key('unwrapKey') - cek = op_key.decrypt(ek, self.padding) - print(cek, enc_alg.key_size) - if len(cek) * 8 != enc_alg.CEK_SIZE: - raise ValueError('Invalid "cek" length') - return cek - - -class AESAlgorithm(JWEAlgorithm): - def __init__(self, key_size): - self.name = 'A{}KW'.format(key_size) - self.description = 'AES Key Wrap using {}-bit key'.format(key_size) - self.key_size = key_size - - def prepare_key(self, raw_data): - return OctKey.import_key(raw_data) - - def _check_key(self, key): - if len(key) * 8 != self.key_size: - raise ValueError( - 'A key of size {} bits is required.'.format(self.key_size)) - - def wrap(self, enc_alg, headers, key): - cek = enc_alg.generate_cek() - op_key = key.get_op_key('wrapKey') - self._check_key(op_key) - ek = aes_key_wrap(op_key, cek, default_backend()) - return {'ek': ek, 'cek': cek} - - def unwrap(self, enc_alg, ek, headers, key): - op_key = key.get_op_key('unwrapKey') - self._check_key(op_key) - cek = aes_key_unwrap(op_key, ek, default_backend()) - if len(cek) * 8 != enc_alg.CEK_SIZE: - raise ValueError('Invalid "cek" length') - return cek - - -class AESGCMAlgorithm(JWEAlgorithm): - EXTRA_HEADERS = frozenset(['iv', 'tag']) - - def __init__(self, key_size): - self.name = 'A{}GCMKW'.format(key_size) - self.description = 'Key wrapping with AES GCM using {}-bit key'.format(key_size) - self.key_size = key_size - - def prepare_key(self, raw_data): - return OctKey.import_key(raw_data) - - def _check_key(self, key): - if len(key) * 8 != self.key_size: - raise ValueError( - 'A key of size {} bits is required.'.format(self.key_size)) - - def wrap(self, enc_alg, headers, key): - cek = enc_alg.generate_cek() - op_key = key.get_op_key('wrapKey') - self._check_key(op_key) - - #: https://tools.ietf.org/html/rfc7518#section-4.7.1.1 - #: The "iv" (initialization vector) Header Parameter value is the - #: base64url-encoded representation of the 96-bit IV value - iv_size = 96 - iv = os.urandom(iv_size // 8) - - cipher = Cipher(AES(op_key), GCM(iv), backend=default_backend()) - enc = cipher.encryptor() - ek = enc.update(cek) + enc.finalize() - - h = { - 'iv': to_native(urlsafe_b64encode(iv)), - 'tag': to_native(urlsafe_b64encode(enc.tag)) - } - return {'ek': ek, 'cek': cek, 'header': h} - - def unwrap(self, enc_alg, ek, headers, key): - op_key = key.get_op_key('unwrapKey') - self._check_key(op_key) - - iv = headers.get('iv') - if not iv: - raise ValueError('Missing "iv" in headers') - - tag = headers.get('tag') - if not tag: - raise ValueError('Missing "tag" in headers') - - iv = urlsafe_b64decode(to_bytes(iv)) - tag = urlsafe_b64decode(to_bytes(tag)) - - cipher = Cipher(AES(op_key), GCM(iv, tag), backend=default_backend()) - d = cipher.decryptor() - cek = d.update(ek) + d.finalize() - if len(cek) * 8 != enc_alg.CEK_SIZE: - raise ValueError('Invalid "cek" length') - return cek - - -class ECDHAlgorithm(JWEAlgorithm): - EXTRA_HEADERS = ['epk', 'apu', 'apv'] - ALLOWED_KEY_CLS = ECKey - - # https://tools.ietf.org/html/rfc7518#section-4.6 - def __init__(self, key_size=None): - if key_size is None: - self.name = 'ECDH-ES' - self.description = 'ECDH-ES in the Direct Key Agreement mode' - else: - self.name = 'ECDH-ES+A{}KW'.format(key_size) - self.description = ( - 'ECDH-ES using Concat KDF and CEK wrapped ' - 'with A{}KW').format(key_size) - self.key_size = key_size - self.aeskw = AESAlgorithm(key_size) - - def prepare_key(self, raw_data): - if isinstance(raw_data, self.ALLOWED_KEY_CLS): - return raw_data - return ECKey.import_key(raw_data) - - def deliver(self, key, pubkey, headers, bit_size): - # AlgorithmID - if self.key_size is None: - alg_id = _u32be_len_input(headers['enc']) - else: - alg_id = _u32be_len_input(headers['alg']) - - # PartyUInfo - apu_info = _u32be_len_input(headers.get('apu'), True) - - # PartyVInfo - apv_info = _u32be_len_input(headers.get('apv'), True) - - # SuppPubInfo - pub_info = struct.pack('>I', bit_size) - - other_info = alg_id + apu_info + apv_info + pub_info - shared_key = key.exchange_shared_key(pubkey) - ckdf = ConcatKDFHash( - algorithm=hashes.SHA256(), - length=bit_size // 8, - otherinfo=other_info, - backend=default_backend() - ) - return ckdf.derive(shared_key) - - def wrap(self, enc_alg, headers, key): - if self.key_size is None: - bit_size = enc_alg.key_size - else: - bit_size = self.key_size - - epk = key.generate_key(key['crv'], is_private=True) - public_key = key.get_op_key('wrapKey') - dk = self.deliver(epk, public_key, headers, bit_size) - - # REQUIRED_JSON_FIELDS contains only public fields - pub_epk = {k: epk[k] for k in epk.REQUIRED_JSON_FIELDS} - pub_epk['kty'] = epk.kty - h = {'epk': pub_epk} - if self.key_size is None: - return {'ek': b'', 'cek': dk, 'header': h} - - kek = self.aeskw.prepare_key(dk) - rv = self.aeskw.wrap(enc_alg, headers, kek) - rv['header'] = h - return rv - - def unwrap(self, enc_alg, ek, headers, key): - if 'epk' not in headers: - raise ValueError('Missing "epk" in headers') - - if self.key_size is None: - bit_size = enc_alg.key_size - else: - bit_size = self.key_size - - epk = key.import_key(headers['epk']) - public_key = epk.get_op_key('wrapKey') - dk = self.deliver(key, public_key, headers, bit_size) - - if self.key_size is None: - return dk - - kek = self.aeskw.prepare_key(dk) - return self.aeskw.unwrap(enc_alg, ek, headers, kek) - - -def _u32be_len_input(s, base64=False): - if not s: - return b'\x00\x00\x00\x00' - if base64: - s = urlsafe_b64decode(to_bytes(s)) - else: - s = to_bytes(s) - return struct.pack('>I', len(s)) + s - - -JWE_ALG_ALGORITHMS = [ - RSAAlgorithm('RSA1_5', 'RSAES-PKCS1-v1_5', padding.PKCS1v15()), - RSAAlgorithm( - 'RSA-OAEP', 'RSAES OAEP using default parameters', - padding.OAEP(padding.MGF1(hashes.SHA1()), hashes.SHA1(), None)), - RSAAlgorithm( - 'RSA-OAEP-256', 'RSAES OAEP using SHA-256 and MGF1 with SHA-256', - padding.OAEP(padding.MGF1(hashes.SHA256()), hashes.SHA256(), None)), - - AESAlgorithm(128), # A128KW - AESAlgorithm(192), # A192KW - AESAlgorithm(256), # A256KW - AESGCMAlgorithm(128), # A128GCMKW - AESGCMAlgorithm(192), # A192GCMKW - AESGCMAlgorithm(256), # A256GCMKW - ECDHAlgorithm(None), # ECDH-ES - ECDHAlgorithm(128), # ECDH-ES+A128KW - ECDHAlgorithm(192), # ECDH-ES+A192KW - ECDHAlgorithm(256), # ECDH-ES+A256KW -] - -# 'PBES2-HS256+A128KW': '', -# 'PBES2-HS384+A192KW': '', -# 'PBES2-HS512+A256KW': '', diff --git a/authlib/jose/rfc7518/_cryptography_backends/_keys.py b/authlib/jose/rfc7518/_cryptography_backends/_keys.py deleted file mode 100644 index 9ca438989..000000000 --- a/authlib/jose/rfc7518/_cryptography_backends/_keys.py +++ /dev/null @@ -1,336 +0,0 @@ -from cryptography.x509 import load_pem_x509_certificate -from cryptography.hazmat.primitives.serialization import ( - load_pem_private_key, load_pem_public_key, load_ssh_public_key, - Encoding, PrivateFormat, PublicFormat, - BestAvailableEncryption, NoEncryption, -) -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.asymmetric.rsa import ( - RSAPublicKey, RSAPrivateKeyWithSerialization, - RSAPrivateNumbers, RSAPublicNumbers, - rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp -) -from cryptography.hazmat.primitives.asymmetric import ec -from cryptography.hazmat.primitives.asymmetric.ec import ( - EllipticCurvePublicKey, EllipticCurvePrivateKeyWithSerialization, - EllipticCurvePrivateNumbers, EllipticCurvePublicNumbers, - SECP256R1, SECP384R1, SECP521R1, -) -from cryptography.hazmat.backends import default_backend -from authlib.jose.rfc7517 import Key -from authlib.common.encoding import to_bytes -from authlib.common.encoding import base64_to_int, int_to_base64 - - -class RSAKey(Key): - """Key class of the ``RSA`` key type.""" - - kty = 'RSA' - RAW_KEY_CLS = (RSAPublicKey, RSAPrivateKeyWithSerialization) - REQUIRED_JSON_FIELDS = ['e', 'n'] - - def as_pem(self, is_private=False, password=None): - """Export key into PEM format bytes. - - :param is_private: export private key or public key - :param password: encrypt private key with password - :return: bytes - """ - return export_key(self, is_private=is_private, password=password) - - @staticmethod - def dumps_private_key(raw_key): - numbers = raw_key.private_numbers() - return { - 'n': int_to_base64(numbers.public_numbers.n), - 'e': int_to_base64(numbers.public_numbers.e), - 'd': int_to_base64(numbers.d), - 'p': int_to_base64(numbers.p), - 'q': int_to_base64(numbers.q), - 'dp': int_to_base64(numbers.dmp1), - 'dq': int_to_base64(numbers.dmq1), - 'qi': int_to_base64(numbers.iqmp) - } - - @staticmethod - def dumps_public_key(raw_key): - numbers = raw_key.public_numbers() - return { - 'n': int_to_base64(numbers.n), - 'e': int_to_base64(numbers.e) - } - - @staticmethod - def loads_private_key(obj): - if 'oth' in obj: # pragma: no cover - # https://tools.ietf.org/html/rfc7518#section-6.3.2.7 - raise ValueError('"oth" is not supported yet') - - props = ['p', 'q', 'dp', 'dq', 'qi'] - props_found = [prop in obj for prop in props] - any_props_found = any(props_found) - - if any_props_found and not all(props_found): - raise ValueError( - 'RSA key must include all parameters ' - 'if any are present besides d') - - public_numbers = RSAPublicNumbers( - base64_to_int(obj['e']), base64_to_int(obj['n'])) - - if any_props_found: - numbers = RSAPrivateNumbers( - d=base64_to_int(obj['d']), - p=base64_to_int(obj['p']), - q=base64_to_int(obj['q']), - dmp1=base64_to_int(obj['dp']), - dmq1=base64_to_int(obj['dq']), - iqmp=base64_to_int(obj['qi']), - public_numbers=public_numbers) - else: - d = base64_to_int(obj['d']) - p, q = rsa_recover_prime_factors( - public_numbers.n, d, public_numbers.e) - numbers = RSAPrivateNumbers( - d=d, - p=p, - q=q, - dmp1=rsa_crt_dmp1(d, p), - dmq1=rsa_crt_dmq1(d, q), - iqmp=rsa_crt_iqmp(p, q), - public_numbers=public_numbers) - - return numbers.private_key(default_backend()) - - @staticmethod - def loads_public_key(obj): - numbers = RSAPublicNumbers( - base64_to_int(obj['e']), - base64_to_int(obj['n']) - ) - return numbers.public_key(default_backend()) - - @classmethod - def import_key(cls, raw, options=None): - """Import a key from PEM or dict data.""" - return import_key( - cls, raw, - RSAPublicKey, RSAPrivateKeyWithSerialization, - b'ssh-rsa', options - ) - - @classmethod - def generate_key(cls, key_size=2048, options=None, is_private=False): - if key_size < 512: - raise ValueError('key_size must not be less than 512') - if key_size % 8 != 0: - raise ValueError('Invalid key_size for RSAKey') - raw_key = rsa.generate_private_key( - public_exponent=65537, - key_size=key_size, - backend=default_backend(), - ) - if not is_private: - raw_key = raw_key.public_key() - return cls.import_key(raw_key, options=options) - - -class ECKey(Key): - """Key class of the ``EC`` key type.""" - - kty = 'EC' - DSS_CURVES = { - 'P-256': SECP256R1, - 'P-384': SECP384R1, - 'P-521': SECP521R1, - } - CURVES_DSS = { - SECP256R1.name: 'P-256', - SECP384R1.name: 'P-384', - SECP521R1.name: 'P-521', - } - REQUIRED_JSON_FIELDS = ['crv', 'x', 'y'] - RAW_KEY_CLS = (EllipticCurvePublicKey, EllipticCurvePrivateKeyWithSerialization) - - def as_pem(self, is_private=False, password=None): - """Export key into PEM format bytes. - - :param is_private: export private key or public key - :param password: encrypt private key with password - :return: bytes - """ - return export_key(self, is_private=is_private, password=password) - - def exchange_shared_key(self, pubkey): - # # used in ECDHAlgorithm - if isinstance(self.raw_key, EllipticCurvePrivateKeyWithSerialization): - return self.raw_key.exchange(ec.ECDH(), pubkey) - raise ValueError('Invalid key for exchanging shared key') - - @property - def curve_key_size(self): - return self.raw_key.curve.key_size - - @classmethod - def loads_private_key(cls, obj): - curve = cls.DSS_CURVES[obj['crv']]() - public_numbers = EllipticCurvePublicNumbers( - base64_to_int(obj['x']), - base64_to_int(obj['y']), - curve, - ) - private_numbers = EllipticCurvePrivateNumbers( - base64_to_int(obj['d']), - public_numbers - ) - return private_numbers.private_key(default_backend()) - - @classmethod - def loads_public_key(cls, obj): - curve = cls.DSS_CURVES[obj['crv']]() - public_numbers = EllipticCurvePublicNumbers( - base64_to_int(obj['x']), - base64_to_int(obj['y']), - curve, - ) - return public_numbers.public_key(default_backend()) - - @classmethod - def dumps_private_key(cls, raw_key): - numbers = raw_key.private_numbers() - return { - 'crv': cls.CURVES_DSS[raw_key.curve.name], - 'x': int_to_base64(numbers.public_numbers.x), - 'y': int_to_base64(numbers.public_numbers.y), - 'd': int_to_base64(numbers.private_value), - } - - @classmethod - def dumps_public_key(cls, raw_key): - numbers = raw_key.public_numbers() - return { - 'crv': cls.CURVES_DSS[numbers.curve.name], - 'x': int_to_base64(numbers.x), - 'y': int_to_base64(numbers.y) - } - - @classmethod - def import_key(cls, raw, options=None): - """Import a key from PEM or dict data.""" - return import_key( - cls, raw, - EllipticCurvePublicKey, EllipticCurvePrivateKeyWithSerialization, - b'ecdsa-sha2-', options - ) - - @classmethod - def generate_key(cls, crv='P-256', options=None, is_private=False): - if crv not in cls.DSS_CURVES: - raise ValueError('Invalid crv value: "{}"'.format(crv)) - raw_key = ec.generate_private_key( - curve=cls.DSS_CURVES[crv](), - backend=default_backend(), - ) - if not is_private: - raw_key = raw_key.public_key() - return cls.import_key(raw_key, options=options) - - -def load_pem_key(raw, ssh_type=None, key_type=None, password=None): - raw = to_bytes(raw) - - if ssh_type and raw.startswith(ssh_type): - return load_ssh_public_key(raw, backend=default_backend()) - - if key_type == 'public': - return load_pem_public_key(raw, backend=default_backend()) - - if key_type == 'private' or password is not None: - return load_pem_private_key(raw, password=password, backend=default_backend()) - - if b'PUBLIC' in raw: - return load_pem_public_key(raw, backend=default_backend()) - - if b'PRIVATE' in raw: - return load_pem_private_key(raw, password=password, backend=default_backend()) - - if b'CERTIFICATE' in raw: - cert = load_pem_x509_certificate(raw, default_backend()) - return cert.public_key() - - try: - return load_pem_private_key(raw, password=password, backend=default_backend()) - except ValueError: - return load_pem_public_key(raw, backend=default_backend()) - - -def import_key(cls, raw, public_key_cls, private_key_cls, ssh_type=None, options=None): - if isinstance(raw, cls): - if options is not None: - raw.update(options) - return raw - - payload = None - if isinstance(raw, (public_key_cls, private_key_cls)): - raw_key = raw - elif isinstance(raw, dict): - cls.check_required_fields(raw) - payload = raw - if 'd' in payload: - raw_key = cls.loads_private_key(payload) - else: - raw_key = cls.loads_public_key(payload) - else: - if options is not None: - password = options.get('password') - else: - password = None - raw_key = load_pem_key(raw, ssh_type, password=password) - - if isinstance(raw_key, private_key_cls): - if payload is None: - payload = cls.dumps_private_key(raw_key) - key_type = 'private' - elif isinstance(raw_key, public_key_cls): - if payload is None: - payload = cls.dumps_public_key(raw_key) - key_type = 'public' - else: - raise ValueError('Invalid data for importing key') - - obj = cls(payload) - obj.raw_key = raw_key - obj.key_type = key_type - return obj - - -def export_key(key, encoding=None, is_private=False, password=None): - if encoding is None or encoding == 'PEM': - encoding = Encoding.PEM - elif encoding == 'DER': - encoding = Encoding.DER - else: - raise ValueError('Invalid encoding: {!r}'.format(encoding)) - - if is_private: - if key.key_type == 'private': - if password is None: - encryption_algorithm = NoEncryption() - else: - encryption_algorithm = BestAvailableEncryption(to_bytes(password)) - return key.raw_key.private_bytes( - encoding=encoding, - format=PrivateFormat.PKCS8, - encryption_algorithm=encryption_algorithm, - ) - raise ValueError('This is a public key') - - if key.key_type == 'private': - raw_key = key.raw_key.public_key() - else: - raw_key = key.raw_key - - return raw_key.public_bytes( - encoding=encoding, - format=PublicFormat.SubjectPublicKeyInfo, - ) diff --git a/authlib/jose/rfc7518/ec_key.py b/authlib/jose/rfc7518/ec_key.py new file mode 100644 index 000000000..82ec6a4b4 --- /dev/null +++ b/authlib/jose/rfc7518/ec_key.py @@ -0,0 +1,108 @@ +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric.ec import SECP256K1 +from cryptography.hazmat.primitives.asymmetric.ec import SECP256R1 +from cryptography.hazmat.primitives.asymmetric.ec import SECP384R1 +from cryptography.hazmat.primitives.asymmetric.ec import SECP521R1 +from cryptography.hazmat.primitives.asymmetric.ec import ( + EllipticCurvePrivateKeyWithSerialization, +) +from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateNumbers +from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey +from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicNumbers + +from authlib.common.encoding import base64_to_int +from authlib.common.encoding import int_to_base64 + +from ..rfc7517 import AsymmetricKey + + +class ECKey(AsymmetricKey): + """Key class of the ``EC`` key type.""" + + kty = "EC" + DSS_CURVES = { + "P-256": SECP256R1, + "P-384": SECP384R1, + "P-521": SECP521R1, + # https://tools.ietf.org/html/rfc8812#section-3.1 + "secp256k1": SECP256K1, + } + CURVES_DSS = { + SECP256R1.name: "P-256", + SECP384R1.name: "P-384", + SECP521R1.name: "P-521", + SECP256K1.name: "secp256k1", + } + REQUIRED_JSON_FIELDS = ["crv", "x", "y"] + + PUBLIC_KEY_FIELDS = REQUIRED_JSON_FIELDS + PRIVATE_KEY_FIELDS = ["crv", "d", "x", "y"] + + PUBLIC_KEY_CLS = EllipticCurvePublicKey + PRIVATE_KEY_CLS = EllipticCurvePrivateKeyWithSerialization + SSH_PUBLIC_PREFIX = b"ecdsa-sha2-" + + def exchange_shared_key(self, pubkey): + # # used in ECDHESAlgorithm + private_key = self.get_private_key() + if private_key: + return private_key.exchange(ec.ECDH(), pubkey) + raise ValueError("Invalid key for exchanging shared key") + + @property + def curve_key_size(self): + raw_key = self.get_private_key() + if not raw_key: + raw_key = self.public_key + return raw_key.curve.key_size + + def load_private_key(self): + curve = self.DSS_CURVES[self._dict_data["crv"]]() + public_numbers = EllipticCurvePublicNumbers( + base64_to_int(self._dict_data["x"]), + base64_to_int(self._dict_data["y"]), + curve, + ) + private_numbers = EllipticCurvePrivateNumbers( + base64_to_int(self.tokens["d"]), public_numbers + ) + return private_numbers.private_key(default_backend()) + + def load_public_key(self): + curve = self.DSS_CURVES[self._dict_data["crv"]]() + public_numbers = EllipticCurvePublicNumbers( + base64_to_int(self._dict_data["x"]), + base64_to_int(self._dict_data["y"]), + curve, + ) + return public_numbers.public_key(default_backend()) + + def dumps_private_key(self): + numbers = self.private_key.private_numbers() + return { + "crv": self.CURVES_DSS[self.private_key.curve.name], + "x": int_to_base64(numbers.public_numbers.x), + "y": int_to_base64(numbers.public_numbers.y), + "d": int_to_base64(numbers.private_value), + } + + def dumps_public_key(self): + numbers = self.public_key.public_numbers() + return { + "crv": self.CURVES_DSS[numbers.curve.name], + "x": int_to_base64(numbers.x), + "y": int_to_base64(numbers.y), + } + + @classmethod + def generate_key(cls, crv="P-256", options=None, is_private=False) -> "ECKey": + if crv not in cls.DSS_CURVES: + raise ValueError(f'Invalid crv value: "{crv}"') + raw_key = ec.generate_private_key( + curve=cls.DSS_CURVES[crv](), + backend=default_backend(), + ) + if not is_private: + raw_key = raw_key.public_key() + return cls.import_key(raw_key, options=options) diff --git a/authlib/jose/rfc7518/jwe_algorithms.py b/authlib/jose/rfc7518/jwe_algorithms.py deleted file mode 100644 index 1e5dc961f..000000000 --- a/authlib/jose/rfc7518/jwe_algorithms.py +++ /dev/null @@ -1,50 +0,0 @@ -import zlib -from .oct_key import OctKey -from ._cryptography_backends import JWE_ALG_ALGORITHMS, JWE_ENC_ALGORITHMS -from ..rfc7516 import JWEAlgorithm, JWEZipAlgorithm, JsonWebEncryption - - -class DirectAlgorithm(JWEAlgorithm): - name = 'dir' - description = 'Direct use of a shared symmetric key' - - def prepare_key(self, raw_data): - return OctKey.import_key(raw_data) - - def wrap(self, enc_alg, headers, key): - cek = key.get_op_key('encrypt') - if len(cek) * 8 != enc_alg.CEK_SIZE: - raise ValueError('Invalid "cek" length') - return {'ek': b'', 'cek': cek} - - def unwrap(self, enc_alg, ek, headers, key): - cek = key.get_op_key('decrypt') - if len(cek) * 8 != enc_alg.CEK_SIZE: - raise ValueError('Invalid "cek" length') - return cek - - -class DeflateZipAlgorithm(JWEZipAlgorithm): - name = 'DEF' - description = 'DEFLATE' - - def compress(self, s): - """Compress bytes data with DEFLATE algorithm.""" - data = zlib.compress(s) - # drop gzip headers and tail - return data[2:-4] - - def decompress(self, s): - """Decompress DEFLATE bytes data.""" - return zlib.decompress(s, -zlib.MAX_WBITS) - - -def register_jwe_rfc7518(): - JsonWebEncryption.register_algorithm(DirectAlgorithm()) - JsonWebEncryption.register_algorithm(DeflateZipAlgorithm()) - - for algorithm in JWE_ALG_ALGORITHMS: - JsonWebEncryption.register_algorithm(algorithm) - - for algorithm in JWE_ENC_ALGORITHMS: - JsonWebEncryption.register_algorithm(algorithm) diff --git a/authlib/jose/rfc7518/jwe_algs.py b/authlib/jose/rfc7518/jwe_algs.py new file mode 100644 index 000000000..778cc478c --- /dev/null +++ b/authlib/jose/rfc7518/jwe_algs.py @@ -0,0 +1,350 @@ +import secrets +import struct + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import padding +from cryptography.hazmat.primitives.ciphers import Cipher +from cryptography.hazmat.primitives.ciphers.algorithms import AES +from cryptography.hazmat.primitives.ciphers.modes import GCM +from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash +from cryptography.hazmat.primitives.keywrap import aes_key_unwrap +from cryptography.hazmat.primitives.keywrap import aes_key_wrap + +from authlib.common.encoding import to_bytes +from authlib.common.encoding import to_native +from authlib.common.encoding import urlsafe_b64decode +from authlib.common.encoding import urlsafe_b64encode +from authlib.jose.rfc7516 import JWEAlgorithm + +from .ec_key import ECKey +from .oct_key import OctKey +from .rsa_key import RSAKey + + +class DirectAlgorithm(JWEAlgorithm): + name = "dir" + description = "Direct use of a shared symmetric key" + + def prepare_key(self, raw_data): + return OctKey.import_key(raw_data) + + def generate_preset(self, enc_alg, key): + return {} + + def wrap(self, enc_alg, headers, key, preset=None): + cek = key.get_op_key("encrypt") + if len(cek) * 8 != enc_alg.CEK_SIZE: + raise ValueError('Invalid "cek" length') + return {"ek": b"", "cek": cek} + + def unwrap(self, enc_alg, ek, headers, key): + cek = key.get_op_key("decrypt") + if len(cek) * 8 != enc_alg.CEK_SIZE: + cek = secrets.token_bytes(enc_alg.CEK_SIZE // 8) + return cek + + +class RSAAlgorithm(JWEAlgorithm): + #: A key of size 2048 bits or larger MUST be used with these algorithms + #: RSA1_5, RSA-OAEP, RSA-OAEP-256 + key_size = 2048 + + def __init__(self, name, description, pad_fn): + self.name = name + self.deprecated = name == "RSA1_5" + self.description = description + self.padding = pad_fn + + def prepare_key(self, raw_data): + return RSAKey.import_key(raw_data) + + def generate_preset(self, enc_alg, key): + cek = enc_alg.generate_cek() + return {"cek": cek} + + def wrap(self, enc_alg, headers, key, preset=None): + if preset and "cek" in preset: + cek = preset["cek"] + else: + cek = enc_alg.generate_cek() + + op_key = key.get_op_key("wrapKey") + if op_key.key_size < self.key_size: + raise ValueError("A key of size 2048 bits or larger MUST be used") + ek = op_key.encrypt(cek, self.padding) + return {"ek": ek, "cek": cek} + + def unwrap(self, enc_alg, ek, headers, key): + op_key = key.get_op_key("unwrapKey") + cek = op_key.decrypt(ek, self.padding) + if len(cek) * 8 != enc_alg.CEK_SIZE: + cek = secrets.token_bytes(enc_alg.CEK_SIZE // 8) + return cek + + +class AESAlgorithm(JWEAlgorithm): + def __init__(self, key_size): + self.name = f"A{key_size}KW" + self.description = f"AES Key Wrap using {key_size}-bit key" + self.key_size = key_size + + def prepare_key(self, raw_data): + return OctKey.import_key(raw_data) + + def generate_preset(self, enc_alg, key): + cek = enc_alg.generate_cek() + return {"cek": cek} + + def _check_key(self, key): + if len(key) * 8 != self.key_size: + raise ValueError(f"A key of size {self.key_size} bits is required.") + + def wrap_cek(self, cek, key): + op_key = key.get_op_key("wrapKey") + self._check_key(op_key) + ek = aes_key_wrap(op_key, cek, default_backend()) + return {"ek": ek, "cek": cek} + + def wrap(self, enc_alg, headers, key, preset=None): + if preset and "cek" in preset: + cek = preset["cek"] + else: + cek = enc_alg.generate_cek() + return self.wrap_cek(cek, key) + + def unwrap(self, enc_alg, ek, headers, key): + op_key = key.get_op_key("unwrapKey") + self._check_key(op_key) + cek = aes_key_unwrap(op_key, ek, default_backend()) + if len(cek) * 8 != enc_alg.CEK_SIZE: + cek = secrets.token_bytes(enc_alg.CEK_SIZE // 8) + return cek + + +class AESGCMAlgorithm(JWEAlgorithm): + EXTRA_HEADERS = frozenset(["iv", "tag"]) + + def __init__(self, key_size): + self.name = f"A{key_size}GCMKW" + self.description = f"Key wrapping with AES GCM using {key_size}-bit key" + self.key_size = key_size + + def prepare_key(self, raw_data): + return OctKey.import_key(raw_data) + + def generate_preset(self, enc_alg, key): + cek = enc_alg.generate_cek() + return {"cek": cek} + + def _check_key(self, key): + if len(key) * 8 != self.key_size: + raise ValueError(f"A key of size {self.key_size} bits is required.") + + def wrap(self, enc_alg, headers, key, preset=None): + if preset and "cek" in preset: + cek = preset["cek"] + else: + cek = enc_alg.generate_cek() + + op_key = key.get_op_key("wrapKey") + self._check_key(op_key) + + #: https://tools.ietf.org/html/rfc7518#section-4.7.1.1 + #: The "iv" (initialization vector) Header Parameter value is the + #: base64url-encoded representation of the 96-bit IV value + iv_size = 96 + iv = secrets.token_bytes(iv_size // 8) + + cipher = Cipher(AES(op_key), GCM(iv), backend=default_backend()) + enc = cipher.encryptor() + ek = enc.update(cek) + enc.finalize() + + h = { + "iv": to_native(urlsafe_b64encode(iv)), + "tag": to_native(urlsafe_b64encode(enc.tag)), + } + return {"ek": ek, "cek": cek, "header": h} + + def unwrap(self, enc_alg, ek, headers, key): + op_key = key.get_op_key("unwrapKey") + self._check_key(op_key) + + iv = headers.get("iv") + if not iv: + raise ValueError('Missing "iv" in headers') + + tag = headers.get("tag") + if not tag: + raise ValueError('Missing "tag" in headers') + + iv = urlsafe_b64decode(to_bytes(iv)) + tag = urlsafe_b64decode(to_bytes(tag)) + + cipher = Cipher(AES(op_key), GCM(iv, tag), backend=default_backend()) + d = cipher.decryptor() + cek = d.update(ek) + d.finalize() + if len(cek) * 8 != enc_alg.CEK_SIZE: + cek = secrets.token_bytes(enc_alg.CEK_SIZE // 8) + return cek + + +class ECDHESAlgorithm(JWEAlgorithm): + EXTRA_HEADERS = ["epk", "apu", "apv"] + ALLOWED_KEY_CLS = ECKey + + # https://tools.ietf.org/html/rfc7518#section-4.6 + def __init__(self, key_size=None): + if key_size is None: + self.name = "ECDH-ES" + self.description = "ECDH-ES in the Direct Key Agreement mode" + else: + self.name = f"ECDH-ES+A{key_size}KW" + self.description = ( + f"ECDH-ES using Concat KDF and CEK wrapped with A{key_size}KW" + ) + self.key_size = key_size + self.aeskw = AESAlgorithm(key_size) + + def prepare_key(self, raw_data): + if isinstance(raw_data, self.ALLOWED_KEY_CLS): + return raw_data + return ECKey.import_key(raw_data) + + def generate_preset(self, enc_alg, key): + epk = self._generate_ephemeral_key(key) + h = self._prepare_headers(epk) + preset = {"epk": epk, "header": h} + if self.key_size is not None: + cek = enc_alg.generate_cek() + preset["cek"] = cek + return preset + + def compute_fixed_info(self, headers, bit_size): + # AlgorithmID + if self.key_size is None: + alg_id = u32be_len_input(headers["enc"]) + else: + alg_id = u32be_len_input(headers["alg"]) + + # PartyUInfo + apu_info = u32be_len_input(headers.get("apu"), True) + + # PartyVInfo + apv_info = u32be_len_input(headers.get("apv"), True) + + # SuppPubInfo + pub_info = struct.pack(">I", bit_size) + + return alg_id + apu_info + apv_info + pub_info + + def compute_derived_key(self, shared_key, fixed_info, bit_size): + ckdf = ConcatKDFHash( + algorithm=hashes.SHA256(), + length=bit_size // 8, + otherinfo=fixed_info, + backend=default_backend(), + ) + return ckdf.derive(shared_key) + + def deliver(self, key, pubkey, headers, bit_size): + shared_key = key.exchange_shared_key(pubkey) + fixed_info = self.compute_fixed_info(headers, bit_size) + return self.compute_derived_key(shared_key, fixed_info, bit_size) + + def _generate_ephemeral_key(self, key): + return key.generate_key(key["crv"], is_private=True) + + def _prepare_headers(self, epk): + # REQUIRED_JSON_FIELDS contains only public fields + pub_epk = {k: epk[k] for k in epk.REQUIRED_JSON_FIELDS} + pub_epk["kty"] = epk.kty + return {"epk": pub_epk} + + def wrap(self, enc_alg, headers, key, preset=None): + if self.key_size is None: + bit_size = enc_alg.CEK_SIZE + else: + bit_size = self.key_size + + if preset and "epk" in preset: + epk = preset["epk"] + h = {} + else: + epk = self._generate_ephemeral_key(key) + h = self._prepare_headers(epk) + + public_key = key.get_op_key("wrapKey") + dk = self.deliver(epk, public_key, headers, bit_size) + + if self.key_size is None: + return {"ek": b"", "cek": dk, "header": h} + + if preset and "cek" in preset: + preset_for_kw = {"cek": preset["cek"]} + else: + preset_for_kw = None + + kek = self.aeskw.prepare_key(dk) + rv = self.aeskw.wrap(enc_alg, headers, kek, preset_for_kw) + rv["header"] = h + return rv + + def unwrap(self, enc_alg, ek, headers, key): + if "epk" not in headers: + raise ValueError('Missing "epk" in headers') + + if self.key_size is None: + bit_size = enc_alg.CEK_SIZE + else: + bit_size = self.key_size + + epk = key.import_key(headers["epk"]) + public_key = epk.get_op_key("wrapKey") + dk = self.deliver(key, public_key, headers, bit_size) + + if self.key_size is None: + return dk + + kek = self.aeskw.prepare_key(dk) + return self.aeskw.unwrap(enc_alg, ek, headers, kek) + + +def u32be_len_input(s, base64=False): + if not s: + return b"\x00\x00\x00\x00" + if base64: + s = urlsafe_b64decode(to_bytes(s)) + else: + s = to_bytes(s) + return struct.pack(">I", len(s)) + s + + +JWE_ALG_ALGORITHMS = [ + DirectAlgorithm(), # dir + RSAAlgorithm("RSA1_5", "RSAES-PKCS1-v1_5", padding.PKCS1v15()), + RSAAlgorithm( + "RSA-OAEP", + "RSAES OAEP using default parameters", + padding.OAEP(padding.MGF1(hashes.SHA1()), hashes.SHA1(), None), + ), + RSAAlgorithm( + "RSA-OAEP-256", + "RSAES OAEP using SHA-256 and MGF1 with SHA-256", + padding.OAEP(padding.MGF1(hashes.SHA256()), hashes.SHA256(), None), + ), + AESAlgorithm(128), # A128KW + AESAlgorithm(192), # A192KW + AESAlgorithm(256), # A256KW + AESGCMAlgorithm(128), # A128GCMKW + AESGCMAlgorithm(192), # A192GCMKW + AESGCMAlgorithm(256), # A256GCMKW + ECDHESAlgorithm(None), # ECDH-ES + ECDHESAlgorithm(128), # ECDH-ES+A128KW + ECDHESAlgorithm(192), # ECDH-ES+A192KW + ECDHESAlgorithm(256), # ECDH-ES+A256KW +] + +# 'PBES2-HS256+A128KW': '', +# 'PBES2-HS384+A192KW': '', +# 'PBES2-HS512+A256KW': '', diff --git a/authlib/jose/rfc7518/_cryptography_backends/_jwe_enc.py b/authlib/jose/rfc7518/jwe_encs.py similarity index 81% rename from authlib/jose/rfc7518/_cryptography_backends/_jwe_enc.py rename to authlib/jose/rfc7518/jwe_encs.py index f955a7c51..38246131b 100644 --- a/authlib/jose/rfc7518/_cryptography_backends/_jwe_enc.py +++ b/authlib/jose/rfc7518/jwe_encs.py @@ -1,22 +1,25 @@ -""" - authlib.jose.rfc7518 - ~~~~~~~~~~~~~~~~~~~~ +"""authlib.jose.rfc7518. +~~~~~~~~~~~~~~~~~~~~ - Cryptographic Algorithms for Cryptographic Algorithms for Content - Encryption per `Section 5`_. +Cryptographic Algorithms for Cryptographic Algorithms for Content +Encryption per `Section 5`_. - .. _`Section 5`: https://tools.ietf.org/html/rfc7518#section-5 +.. _`Section 5`: https://tools.ietf.org/html/rfc7518#section-5 """ -import hmac + import hashlib +import hmac + +from cryptography.exceptions import InvalidTag from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers import Cipher from cryptography.hazmat.primitives.ciphers.algorithms import AES -from cryptography.hazmat.primitives.ciphers.modes import GCM, CBC +from cryptography.hazmat.primitives.ciphers.modes import CBC +from cryptography.hazmat.primitives.ciphers.modes import GCM from cryptography.hazmat.primitives.padding import PKCS7 -from cryptography.exceptions import InvalidTag -from authlib.jose.rfc7516 import JWEEncAlgorithm -from ..util import encode_int + +from ..rfc7516 import JWEEncAlgorithm +from .util import encode_int class CBCHS2EncAlgorithm(JWEEncAlgorithm): @@ -25,8 +28,8 @@ class CBCHS2EncAlgorithm(JWEEncAlgorithm): IV_SIZE = 128 def __init__(self, key_size, hash_type): - self.name = 'A{}CBC-HS{}'.format(key_size, hash_type) - tpl = 'AES_{}_CBC_HMAC_SHA_{} authenticated encryption algorithm' + self.name = f"A{key_size}CBC-HS{hash_type}" + tpl = "AES_{}_CBC_HMAC_SHA_{} authenticated encryption algorithm" self.description = tpl.format(key_size, hash_type) # bit length @@ -35,13 +38,13 @@ def __init__(self, key_size, hash_type): self.key_len = key_size // 8 self.CEK_SIZE = key_size * 2 - self.hash_alg = getattr(hashlib, 'sha{}'.format(hash_type)) + self.hash_alg = getattr(hashlib, f"sha{hash_type}") def _hmac(self, ciphertext, aad, iv, key): al = encode_int(len(aad) * 8, 64) msg = aad + iv + ciphertext + al d = hmac.new(key, msg, self.hash_alg).digest() - return d[:self.key_len] + return d[: self.key_len] def encrypt(self, msg, aad, iv, key): """Key Encryption with AES_CBC_HMAC_SHA2. @@ -53,8 +56,8 @@ def encrypt(self, msg, aad, iv, key): :return: (ciphertext, iv, tag) """ self.check_iv(iv) - hkey = key[:self.key_len] - ekey = key[self.key_len:] + hkey = key[: self.key_len] + ekey = key[self.key_len :] pad = PKCS7(AES.block_size).padder() padded_data = pad.update(msg) + pad.finalize() @@ -76,8 +79,8 @@ def decrypt(self, ciphertext, aad, iv, tag, key): :return: message """ self.check_iv(iv) - hkey = key[:self.key_len] - dkey = key[self.key_len:] + hkey = key[: self.key_len] + dkey = key[self.key_len :] _tag = self._hmac(ciphertext, aad, iv, hkey) if not hmac.compare_digest(_tag, tag): @@ -96,13 +99,13 @@ class GCMEncAlgorithm(JWEEncAlgorithm): IV_SIZE = 96 def __init__(self, key_size): - self.name = 'A{}GCM'.format(key_size) - self.description = 'AES GCM using {}-bit key'.format(key_size) + self.name = f"A{key_size}GCM" + self.description = f"AES GCM using {key_size}-bit key" self.key_size = key_size self.CEK_SIZE = key_size def encrypt(self, msg, aad, iv, key): - """Key Encryption with AES GCM + """Key Encryption with AES GCM. :param msg: text to be encrypt in bytes :param aad: additional authenticated data in bytes @@ -118,7 +121,7 @@ def encrypt(self, msg, aad, iv, key): return ciphertext, enc.tag def decrypt(self, ciphertext, aad, iv, tag, key): - """Key Decryption with AES GCM + """Key Decryption with AES GCM. :param ciphertext: ciphertext in bytes :param aad: additional authenticated data in bytes diff --git a/authlib/jose/rfc7518/jwe_zips.py b/authlib/jose/rfc7518/jwe_zips.py new file mode 100644 index 000000000..70b1c5cf4 --- /dev/null +++ b/authlib/jose/rfc7518/jwe_zips.py @@ -0,0 +1,34 @@ +import zlib + +from ..rfc7516 import JsonWebEncryption +from ..rfc7516 import JWEZipAlgorithm + +GZIP_HEAD = bytes([120, 156]) +MAX_SIZE = 250 * 1024 + + +class DeflateZipAlgorithm(JWEZipAlgorithm): + name = "DEF" + description = "DEFLATE" + + def compress(self, s: bytes) -> bytes: + """Compress bytes data with DEFLATE algorithm.""" + data = zlib.compress(s) + # https://datatracker.ietf.org/doc/html/rfc1951 + # since DEF is always gzip, we can drop gzip headers and tail + return data[2:-4] + + def decompress(self, s: bytes) -> bytes: + """Decompress DEFLATE bytes data.""" + if s.startswith(GZIP_HEAD): + decompressor = zlib.decompressobj() + else: + decompressor = zlib.decompressobj(-zlib.MAX_WBITS) + value = decompressor.decompress(s, MAX_SIZE) + if decompressor.unconsumed_tail: + raise ValueError(f"Decompressed string exceeds {MAX_SIZE} bytes") + return value + + +def register_jwe_rfc7518(): + JsonWebEncryption.register_algorithm(DeflateZipAlgorithm()) diff --git a/authlib/jose/rfc7518/jws_algorithms.py b/authlib/jose/rfc7518/jws_algorithms.py deleted file mode 100644 index 637298554..000000000 --- a/authlib/jose/rfc7518/jws_algorithms.py +++ /dev/null @@ -1,68 +0,0 @@ -# -*- coding: utf-8 -*- -""" - authlib.jose.rfc7518.jws_algorithms - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - "alg" (Algorithm) Header Parameter Values for JWS per `Section 3`_. - - .. _`Section 3`: https://tools.ietf.org/html/rfc7518#section-3 -""" - -import hmac -import hashlib -from .oct_key import OctKey -from ._cryptography_backends import JWS_ALGORITHMS -from ..rfc7515 import JWSAlgorithm, JsonWebSignature - - -class NoneAlgorithm(JWSAlgorithm): - name = 'none' - description = 'No digital signature or MAC performed' - - def prepare_key(self, raw_data): - return None - - def sign(self, msg, key): - return b'' - - def verify(self, msg, sig, key): - return False - - -class HMACAlgorithm(JWSAlgorithm): - """HMAC using SHA algorithms for JWS. Available algorithms: - - - HS256: HMAC using SHA-256 - - HS384: HMAC using SHA-384 - - HS512: HMAC using SHA-512 - """ - SHA256 = hashlib.sha256 - SHA384 = hashlib.sha384 - SHA512 = hashlib.sha512 - - def __init__(self, sha_type): - self.name = 'HS{}'.format(sha_type) - self.description = 'HMAC using SHA-{}'.format(sha_type) - self.hash_alg = getattr(self, 'SHA{}'.format(sha_type)) - - def prepare_key(self, raw_data): - return OctKey.import_key(raw_data) - - def sign(self, msg, key): - # it is faster than the one in cryptography - op_key = key.get_op_key('sign') - return hmac.new(op_key, msg, self.hash_alg).digest() - - def verify(self, msg, sig, key): - op_key = key.get_op_key('verify') - v_sig = hmac.new(op_key, msg, self.hash_alg).digest() - return hmac.compare_digest(sig, v_sig) - - -def register_jws_rfc7518(): - JsonWebSignature.register_algorithm(NoneAlgorithm()) - JsonWebSignature.register_algorithm(HMACAlgorithm(256)) - JsonWebSignature.register_algorithm(HMACAlgorithm(384)) - JsonWebSignature.register_algorithm(HMACAlgorithm(512)) - for algorithm in JWS_ALGORITHMS: - JsonWebSignature.register_algorithm(algorithm) diff --git a/authlib/jose/rfc7518/_cryptography_backends/_jws.py b/authlib/jose/rfc7518/jws_algs.py similarity index 50% rename from authlib/jose/rfc7518/_cryptography_backends/_jws.py rename to authlib/jose/rfc7518/jws_algs.py index 9caee9664..c9e95ec56 100644 --- a/authlib/jose/rfc7518/_cryptography_backends/_jws.py +++ b/authlib/jose/rfc7518/jws_algs.py @@ -1,23 +1,73 @@ -# -*- coding: utf-8 -*- -""" - authlib.jose.rfc7518 - ~~~~~~~~~~~~~~~~~~~~ +"""authlib.jose.rfc7518. +~~~~~~~~~~~~~~~~~~~~ - "alg" (Algorithm) Header Parameter Values for JWS per `Section 3`_. +"alg" (Algorithm) Header Parameter Values for JWS per `Section 3`_. - .. _`Section 3`: https://tools.ietf.org/html/rfc7518#section-3 +.. _`Section 3`: https://tools.ietf.org/html/rfc7518#section-3 """ +import hashlib +import hmac + +from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.asymmetric.utils import ( - decode_dss_signature, encode_dss_signature -) -from cryptography.hazmat.primitives.asymmetric.ec import ECDSA from cryptography.hazmat.primitives.asymmetric import padding -from cryptography.exceptions import InvalidSignature -from authlib.jose.rfc7515 import JWSAlgorithm -from ._keys import RSAKey, ECKey -from ..util import encode_int, decode_int +from cryptography.hazmat.primitives.asymmetric.ec import ECDSA +from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature +from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature + +from ..rfc7515 import JWSAlgorithm +from .ec_key import ECKey +from .oct_key import OctKey +from .rsa_key import RSAKey +from .util import decode_int +from .util import encode_int + + +class NoneAlgorithm(JWSAlgorithm): + name = "none" + description = "No digital signature or MAC performed" + deprecated = True + + def prepare_key(self, raw_data): + return None + + def sign(self, msg, key): + return b"" + + def verify(self, msg, sig, key): + return sig == b"" + + +class HMACAlgorithm(JWSAlgorithm): + """HMAC using SHA algorithms for JWS. Available algorithms: + + - HS256: HMAC using SHA-256 + - HS384: HMAC using SHA-384 + - HS512: HMAC using SHA-512 + """ + + SHA256 = hashlib.sha256 + SHA384 = hashlib.sha384 + SHA512 = hashlib.sha512 + + def __init__(self, sha_type): + self.name = f"HS{sha_type}" + self.description = f"HMAC using SHA-{sha_type}" + self.hash_alg = getattr(self, f"SHA{sha_type}") + + def prepare_key(self, raw_data): + return OctKey.import_key(raw_data) + + def sign(self, msg, key): + # it is faster than the one in cryptography + op_key = key.get_op_key("sign") + return hmac.new(op_key, msg, self.hash_alg).digest() + + def verify(self, msg, sig, key): + op_key = key.get_op_key("verify") + v_sig = hmac.new(op_key, msg, self.hash_alg).digest() + return hmac.compare_digest(sig, v_sig) class RSAAlgorithm(JWSAlgorithm): @@ -27,25 +77,26 @@ class RSAAlgorithm(JWSAlgorithm): - RS384: RSASSA-PKCS1-v1_5 using SHA-384 - RS512: RSASSA-PKCS1-v1_5 using SHA-512 """ + SHA256 = hashes.SHA256 SHA384 = hashes.SHA384 SHA512 = hashes.SHA512 def __init__(self, sha_type): - self.name = 'RS{}'.format(sha_type) - self.description = 'RSASSA-PKCS1-v1_5 using SHA-{}'.format(sha_type) - self.hash_alg = getattr(self, 'SHA{}'.format(sha_type)) + self.name = f"RS{sha_type}" + self.description = f"RSASSA-PKCS1-v1_5 using SHA-{sha_type}" + self.hash_alg = getattr(self, f"SHA{sha_type}") self.padding = padding.PKCS1v15() def prepare_key(self, raw_data): return RSAKey.import_key(raw_data) def sign(self, msg, key): - op_key = key.get_op_key('sign') + op_key = key.get_op_key("sign") return op_key.sign(msg, self.padding, self.hash_alg()) def verify(self, msg, sig, key): - op_key = key.get_op_key('verify') + op_key = key.get_op_key("verify") try: op_key.verify(sig, msg, self.padding, self.hash_alg()) return True @@ -60,20 +111,27 @@ class ECAlgorithm(JWSAlgorithm): - ES384: ECDSA using P-384 and SHA-384 - ES512: ECDSA using P-521 and SHA-512 """ + SHA256 = hashes.SHA256 SHA384 = hashes.SHA384 SHA512 = hashes.SHA512 - def __init__(self, sha_type): - self.name = 'ES{}'.format(sha_type) - self.description = 'ECDSA using P-{} and SHA-{}'.format(sha_type, sha_type) - self.hash_alg = getattr(self, 'SHA{}'.format(sha_type)) + def __init__(self, name, curve, sha_type): + self.name = name + self.curve = curve + self.description = f"ECDSA using {self.curve} and SHA-{sha_type}" + self.hash_alg = getattr(self, f"SHA{sha_type}") def prepare_key(self, raw_data): - return ECKey.import_key(raw_data) + key = ECKey.import_key(raw_data) + if key["crv"] != self.curve: + raise ValueError( + f'Key for "{self.name}" not supported, only "{self.curve}" allowed' + ) + return key def sign(self, msg, key): - op_key = key.get_op_key('sign') + op_key = key.get_op_key("sign") der_sig = op_key.sign(msg, ECDSA(self.hash_alg())) r, s = decode_dss_signature(der_sig) size = key.curve_key_size @@ -91,7 +149,7 @@ def verify(self, msg, sig, key): der_sig = encode_dss_signature(r, s) try: - op_key = key.get_op_key('verify') + op_key = key.get_op_key("verify") op_key.verify(der_sig, msg, ECDSA(self.hash_alg())) return True except InvalidSignature: @@ -105,41 +163,41 @@ class RSAPSSAlgorithm(JWSAlgorithm): - PS384: RSASSA-PSS using SHA-384 and MGF1 with SHA-384 - PS512: RSASSA-PSS using SHA-512 and MGF1 with SHA-512 """ + SHA256 = hashes.SHA256 SHA384 = hashes.SHA384 SHA512 = hashes.SHA512 def __init__(self, sha_type): - self.name = 'PS{}'.format(sha_type) - tpl = 'RSASSA-PSS using SHA-{} and MGF1 with SHA-{}' + self.name = f"PS{sha_type}" + tpl = "RSASSA-PSS using SHA-{} and MGF1 with SHA-{}" self.description = tpl.format(sha_type, sha_type) - self.hash_alg = getattr(self, 'SHA{}'.format(sha_type)) + self.hash_alg = getattr(self, f"SHA{sha_type}") def prepare_key(self, raw_data): return RSAKey.import_key(raw_data) def sign(self, msg, key): - op_key = key.get_op_key('sign') + op_key = key.get_op_key("sign") return op_key.sign( msg, padding.PSS( - mgf=padding.MGF1(self.hash_alg()), - salt_length=self.hash_alg.digest_size + mgf=padding.MGF1(self.hash_alg()), salt_length=self.hash_alg.digest_size ), - self.hash_alg() + self.hash_alg(), ) def verify(self, msg, sig, key): - op_key = key.get_op_key('verify') + op_key = key.get_op_key("verify") try: op_key.verify( sig, msg, padding.PSS( mgf=padding.MGF1(self.hash_alg()), - salt_length=self.hash_alg.digest_size + salt_length=self.hash_alg.digest_size, ), - self.hash_alg() + self.hash_alg(), ) return True except InvalidSignature: @@ -147,12 +205,17 @@ def verify(self, msg, sig, key): JWS_ALGORITHMS = [ + NoneAlgorithm(), # none + HMACAlgorithm(256), # HS256 + HMACAlgorithm(384), # HS384 + HMACAlgorithm(512), # HS512 RSAAlgorithm(256), # RS256 RSAAlgorithm(384), # RS384 RSAAlgorithm(512), # RS512 - ECAlgorithm(256), # ES256 - ECAlgorithm(384), # ES384 - ECAlgorithm(512), # ES512 + ECAlgorithm("ES256", "P-256", 256), + ECAlgorithm("ES384", "P-384", 384), + ECAlgorithm("ES512", "P-521", 512), + ECAlgorithm("ES256K", "secp256k1", 256), # defined in RFC8812 RSAPSSAlgorithm(256), # PS256 RSAPSSAlgorithm(384), # PS384 RSAPSSAlgorithm(512), # PS512 diff --git a/authlib/jose/rfc7518/oct_key.py b/authlib/jose/rfc7518/oct_key.py index a095ada43..6888c4900 100644 --- a/authlib/jose/rfc7518/oct_key.py +++ b/authlib/jose/rfc7518/oct_key.py @@ -1,48 +1,96 @@ -from authlib.common.encoding import ( - to_bytes, to_unicode, - urlsafe_b64encode, urlsafe_b64decode, +import secrets + +from authlib.common.encoding import to_bytes +from authlib.common.encoding import to_unicode +from authlib.common.encoding import urlsafe_b64decode +from authlib.common.encoding import urlsafe_b64encode + +from ..rfc7517 import Key + +POSSIBLE_UNSAFE_KEYS = ( + b"-----BEGIN ", + b"---- BEGIN ", + b"ssh-rsa ", + b"ssh-dss ", + b"ssh-ed25519 ", + b"ecdsa-sha2-", ) -from authlib.common.security import generate_token -from authlib.jose.rfc7517 import Key class OctKey(Key): """Key class of the ``oct`` key type.""" - kty = 'oct' - REQUIRED_JSON_FIELDS = ['k'] + kty = "oct" + REQUIRED_JSON_FIELDS = ["k"] + + def __init__(self, raw_key=None, options=None): + super().__init__(options) + self.raw_key = raw_key - def get_op_key(self, key_op): - self.check_key_op(key_op) + @property + def public_only(self): + return False + + def get_op_key(self, operation): + """Get the raw key for the given key_op. This method will also + check if the given key_op is supported by this key. + + :param operation: key operation value, such as "sign", "encrypt". + :return: raw key + """ + self.check_key_op(operation) + if not self.raw_key: + self.load_raw_key() return self.raw_key + def load_raw_key(self): + self.raw_key = urlsafe_b64decode(to_bytes(self.tokens["k"])) + + def load_dict_key(self): + k = to_unicode(urlsafe_b64encode(self.raw_key)) + self._dict_data = {"kty": self.kty, "k": k} + + def as_dict(self, is_private=False, **params): + tokens = self.tokens + if "kid" not in tokens: + tokens["kid"] = self.thumbprint() + + tokens.update(params) + return tokens + + @classmethod + def validate_raw_key(cls, key): + return isinstance(key, bytes) + @classmethod def import_key(cls, raw, options=None): """Import a key from bytes, string, or dict data.""" + if isinstance(raw, cls): + if options is not None: + raw.options.update(options) + return raw + if isinstance(raw, dict): cls.check_required_fields(raw) - payload = raw - raw_key = urlsafe_b64decode(to_bytes(payload['k'])) + key = cls(options=options) + key._dict_data = raw else: raw_key = to_bytes(raw) - k = to_unicode(urlsafe_b64encode(raw_key)) - payload = {'k': k} - if options is not None: - payload.update(options) + # security check + if raw_key.startswith(POSSIBLE_UNSAFE_KEYS): + raise ValueError("This key may not be safe to import") - obj = cls(payload) - obj.raw_key = raw_key - obj.key_type = 'secret' - return obj + key = cls(raw_key=raw_key, options=options) + return key @classmethod - def generate_key(cls, key_size=256, options=None, is_private=False): + def generate_key(cls, key_size=256, options=None, is_private=True): """Generate a ``OctKey`` with the given bit size.""" if not is_private: - raise ValueError('oct key can not be generated as public') + raise ValueError("oct key can not be generated as public") if key_size % 8 != 0: - raise ValueError('Invalid bit size for oct key') + raise ValueError("Invalid bit size for oct key") - return cls.import_key(generate_token(key_size // 8), options) + return cls.import_key(secrets.token_bytes(int(key_size / 8)), options) diff --git a/authlib/jose/rfc7518/rsa_key.py b/authlib/jose/rfc7518/rsa_key.py new file mode 100644 index 000000000..6f6db48c7 --- /dev/null +++ b/authlib/jose/rfc7518/rsa_key.py @@ -0,0 +1,127 @@ +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKeyWithSerialization +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateNumbers +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers +from cryptography.hazmat.primitives.asymmetric.rsa import rsa_crt_dmp1 +from cryptography.hazmat.primitives.asymmetric.rsa import rsa_crt_dmq1 +from cryptography.hazmat.primitives.asymmetric.rsa import rsa_crt_iqmp +from cryptography.hazmat.primitives.asymmetric.rsa import rsa_recover_prime_factors + +from authlib.common.encoding import base64_to_int +from authlib.common.encoding import int_to_base64 + +from ..rfc7517 import AsymmetricKey + + +class RSAKey(AsymmetricKey): + """Key class of the ``RSA`` key type.""" + + kty = "RSA" + PUBLIC_KEY_CLS = RSAPublicKey + PRIVATE_KEY_CLS = RSAPrivateKeyWithSerialization + + PUBLIC_KEY_FIELDS = ["e", "n"] + PRIVATE_KEY_FIELDS = ["d", "dp", "dq", "e", "n", "p", "q", "qi"] + REQUIRED_JSON_FIELDS = ["e", "n"] + SSH_PUBLIC_PREFIX = b"ssh-rsa" + + def dumps_private_key(self): + numbers = self.private_key.private_numbers() + return { + "n": int_to_base64(numbers.public_numbers.n), + "e": int_to_base64(numbers.public_numbers.e), + "d": int_to_base64(numbers.d), + "p": int_to_base64(numbers.p), + "q": int_to_base64(numbers.q), + "dp": int_to_base64(numbers.dmp1), + "dq": int_to_base64(numbers.dmq1), + "qi": int_to_base64(numbers.iqmp), + } + + def dumps_public_key(self): + numbers = self.public_key.public_numbers() + return {"n": int_to_base64(numbers.n), "e": int_to_base64(numbers.e)} + + def load_private_key(self): + obj = self._dict_data + + if "oth" in obj: # pragma: no cover + # https://tools.ietf.org/html/rfc7518#section-6.3.2.7 + raise ValueError('"oth" is not supported yet') + + public_numbers = RSAPublicNumbers( + base64_to_int(obj["e"]), base64_to_int(obj["n"]) + ) + + if has_all_prime_factors(obj): + numbers = RSAPrivateNumbers( + d=base64_to_int(obj["d"]), + p=base64_to_int(obj["p"]), + q=base64_to_int(obj["q"]), + dmp1=base64_to_int(obj["dp"]), + dmq1=base64_to_int(obj["dq"]), + iqmp=base64_to_int(obj["qi"]), + public_numbers=public_numbers, + ) + else: + d = base64_to_int(obj["d"]) + p, q = rsa_recover_prime_factors(public_numbers.n, d, public_numbers.e) + numbers = RSAPrivateNumbers( + d=d, + p=p, + q=q, + dmp1=rsa_crt_dmp1(d, p), + dmq1=rsa_crt_dmq1(d, q), + iqmp=rsa_crt_iqmp(p, q), + public_numbers=public_numbers, + ) + + return numbers.private_key(default_backend()) + + def load_public_key(self): + numbers = RSAPublicNumbers( + base64_to_int(self._dict_data["e"]), base64_to_int(self._dict_data["n"]) + ) + return numbers.public_key(default_backend()) + + @classmethod + def generate_key(cls, key_size=2048, options=None, is_private=False) -> "RSAKey": + if key_size < 512: + raise ValueError("key_size must not be less than 512") + if key_size % 8 != 0: + raise ValueError("Invalid key_size for RSAKey") + raw_key = rsa.generate_private_key( + public_exponent=65537, + key_size=key_size, + backend=default_backend(), + ) + if not is_private: + raw_key = raw_key.public_key() + return cls.import_key(raw_key, options=options) + + @classmethod + def import_dict_key(cls, raw, options=None): + cls.check_required_fields(raw) + key = cls(options=options) + key._dict_data = raw + if "d" in raw and not has_all_prime_factors(raw): + # reload dict key + key.load_raw_key() + key.load_dict_key() + return key + + +def has_all_prime_factors(obj): + props = ["p", "q", "dp", "dq", "qi"] + props_found = [prop in obj for prop in props] + if all(props_found): + return True + + if any(props_found): + raise ValueError( + "RSA key must include all parameters if any are present besides d" + ) + + return False diff --git a/authlib/jose/rfc7518/util.py b/authlib/jose/rfc7518/util.py index d2d13ec1f..723770ad3 100644 --- a/authlib/jose/rfc7518/util.py +++ b/authlib/jose/rfc7518/util.py @@ -3,8 +3,8 @@ def encode_int(num, bits): length = ((bits + 7) // 8) * 2 - padded_hex = '%0*x' % (length, num) - big_endian = binascii.a2b_hex(padded_hex.encode('ascii')) + padded_hex = f"{num:0{length}x}" + big_endian = binascii.a2b_hex(padded_hex.encode("ascii")) return big_endian diff --git a/authlib/jose/rfc7519/__init__.py b/authlib/jose/rfc7519/__init__.py index b98efc941..2717e7f64 100644 --- a/authlib/jose/rfc7519/__init__.py +++ b/authlib/jose/rfc7519/__init__.py @@ -1,16 +1,14 @@ -# -*- coding: utf-8 -*- -""" - authlib.jose.rfc7519 - ~~~~~~~~~~~~~~~~~~~~ +"""authlib.jose.rfc7519. +~~~~~~~~~~~~~~~~~~~~ - This module represents a direct implementation of - JSON Web Token (JWT). +This module represents a direct implementation of +JSON Web Token (JWT). - https://tools.ietf.org/html/rfc7519 +https://tools.ietf.org/html/rfc7519 """ +from .claims import BaseClaims +from .claims import JWTClaims from .jwt import JsonWebToken -from .claims import BaseClaims, JWTClaims - -__all__ = ['JsonWebToken', 'BaseClaims', 'JWTClaims'] +__all__ = ["JsonWebToken", "BaseClaims", "JWTClaims"] diff --git a/authlib/jose/rfc7519/claims.py b/authlib/jose/rfc7519/claims.py index 2792c53d3..e9639bc65 100644 --- a/authlib/jose/rfc7519/claims.py +++ b/authlib/jose/rfc7519/claims.py @@ -1,10 +1,9 @@ import time -from authlib.jose.errors import ( - MissingClaimError, - InvalidClaimError, - ExpiredTokenError, - InvalidTokenError, -) + +from authlib.jose.errors import ExpiredTokenError +from authlib.jose.errors import InvalidClaimError +from authlib.jose.errors import InvalidTokenError +from authlib.jose.errors import MissingClaimError class BaseClaims(dict): @@ -35,10 +34,11 @@ class BaseClaims(dict): .. _`OpenID Connect Claims`: http://openid.net/specs/openid-connect-core-1_0.html#IndividualClaimsRequests """ + REGISTERED_CLAIMS = [] def __init__(self, payload, header, options=None, params=None): - super(BaseClaims, self).__init__(payload) + super().__init__(payload) self.header = header self.options = options or {} self.params = params or {} @@ -53,28 +53,31 @@ def __getattr__(self, key): def _validate_essential_claims(self): for k in self.options: - if self.options[k].get('essential') and k not in self: - raise MissingClaimError(k) + if self.options[k].get("essential"): + if k not in self: + raise MissingClaimError(k) + elif not self.get(k): + raise InvalidClaimError(k) def _validate_claim_value(self, claim_name): option = self.options.get(claim_name) - value = self.get(claim_name) - if not option or not value: + if not option: return - option_value = option.get('value') + value = self.get(claim_name) + option_value = option.get("value") if option_value and value != option_value: raise InvalidClaimError(claim_name) - option_values = option.get('values') + option_values = option.get("values") if option_values and value not in option_values: raise InvalidClaimError(claim_name) - validate = option.get('validate') + validate = option.get("validate") if validate and not validate(self, value): raise InvalidClaimError(claim_name) - def get_registered_claims(self): + def get_registered_claims(self): # pragma: no cover rv = {} for k in self.REGISTERED_CLAIMS: if k in self: @@ -83,7 +86,7 @@ def get_registered_claims(self): class JWTClaims(BaseClaims): - REGISTERED_CLAIMS = ['iss', 'sub', 'aud', 'exp', 'nbf', 'iat', 'jti'] + REGISTERED_CLAIMS = ["iss", "sub", "aud", "exp", "nbf", "iat", "jti"] def validate(self, now=None, leeway=0): """Validate everything in claims payload.""" @@ -100,13 +103,18 @@ def validate(self, now=None, leeway=0): self.validate_iat(now, leeway) self.validate_jti() + # Validate custom claims + for key in self.options.keys(): + if key not in self.REGISTERED_CLAIMS: + self._validate_claim_value(key) + def validate_iss(self): """The "iss" (issuer) claim identifies the principal that issued the JWT. The processing of this claim is generally application specific. The "iss" value is a case-sensitive string containing a StringOrURI value. Use of this claim is OPTIONAL. """ - self._validate_claim_value('iss') + self._validate_claim_value("iss") def validate_sub(self): """The "sub" (subject) claim identifies the principal that is the @@ -117,7 +125,7 @@ def validate_sub(self): "sub" value is a case-sensitive string containing a StringOrURI value. Use of this claim is OPTIONAL. """ - self._validate_claim_value('sub') + self._validate_claim_value("sub") def validate_aud(self): """The "aud" (audience) claim identifies the recipients that the JWT is @@ -132,27 +140,27 @@ def validate_aud(self): interpretation of audience values is generally application specific. Use of this claim is OPTIONAL. """ - aud_option = self.options.get('aud') - aud = self.get('aud') + aud_option = self.options.get("aud") + aud = self.get("aud") if not aud_option or not aud: return - aud_values = aud_option.get('values') + aud_values = aud_option.get("values") if not aud_values: - aud_value = aud_option.get('value') + aud_value = aud_option.get("value") if aud_value: aud_values = [aud_value] if not aud_values: return - if isinstance(self['aud'], list): - aud_list = self['aud'] + if isinstance(self["aud"], list): + aud_list = self["aud"] else: - aud_list = [self['aud']] + aud_list = [self["aud"]] if not any([v in aud_list for v in aud_values]): - raise InvalidClaimError('aud') + raise InvalidClaimError("aud") def validate_exp(self, now, leeway): """The "exp" (expiration time) claim identifies the expiration time on @@ -163,10 +171,10 @@ def validate_exp(self, now, leeway): a few minutes, to account for clock skew. Its value MUST be a number containing a NumericDate value. Use of this claim is OPTIONAL. """ - if 'exp' in self: - exp = self['exp'] - if not isinstance(exp, int): - raise InvalidClaimError('exp') + if "exp" in self: + exp = self["exp"] + if not _validate_numeric_time(exp): + raise InvalidClaimError("exp") if exp < (now - leeway): raise ExpiredTokenError() @@ -179,23 +187,28 @@ def validate_nbf(self, now, leeway): account for clock skew. Its value MUST be a number containing a NumericDate value. Use of this claim is OPTIONAL. """ - if 'nbf' in self: - nbf = self['nbf'] - if not isinstance(nbf, int): - raise InvalidClaimError('nbf') + if "nbf" in self: + nbf = self["nbf"] + if not _validate_numeric_time(nbf): + raise InvalidClaimError("nbf") if nbf > (now + leeway): raise InvalidTokenError() def validate_iat(self, now, leeway): """The "iat" (issued at) claim identifies the time at which the JWT was - issued. This claim can be used to determine the age of the JWT. Its - value MUST be a number containing a NumericDate value. Use of this - claim is OPTIONAL. + issued. This claim can be used to determine the age of the JWT. + Implementers MAY provide for some small leeway, usually no more + than a few minutes, to account for clock skew. Its value MUST be a + number containing a NumericDate value. Use of this claim is OPTIONAL. """ - if 'iat' in self: - iat = self['iat'] - if not isinstance(iat, int): - raise InvalidClaimError('iat') + if "iat" in self: + iat = self["iat"] + if not _validate_numeric_time(iat): + raise InvalidClaimError("iat") + if iat > (now + leeway): + raise InvalidTokenError( + description="The token is not valid as it was issued in the future" + ) def validate_jti(self): """The "jti" (JWT ID) claim provides a unique identifier for the JWT. @@ -207,4 +220,8 @@ def validate_jti(self): to prevent the JWT from being replayed. The "jti" value is a case- sensitive string. Use of this claim is OPTIONAL. """ - self._validate_claim_value('jti') + self._validate_claim_value("jti") + + +def _validate_numeric_time(s): + return isinstance(s, (int, float)) diff --git a/authlib/jose/rfc7519/jwt.py b/authlib/jose/rfc7519/jwt.py index 7ffdebcfb..c52e9df9d 100644 --- a/authlib/jose/rfc7519/jwt.py +++ b/authlib/jose/rfc7519/jwt.py @@ -1,30 +1,40 @@ -import re -import datetime import calendar -from authlib.common.encoding import ( - text_types, to_bytes, to_unicode, - json_loads, json_dumps, -) -from .claims import JWTClaims -from ..errors import DecodeError, InsecureClaimError +import datetime +import random +import re + +from authlib.common.encoding import json_dumps +from authlib.common.encoding import json_loads +from authlib.common.encoding import to_bytes +from authlib.common.encoding import to_unicode + +from ..errors import DecodeError +from ..errors import InsecureClaimError from ..rfc7515 import JsonWebSignature from ..rfc7516 import JsonWebEncryption +from ..rfc7517 import Key from ..rfc7517 import KeySet +from .claims import JWTClaims -class JsonWebToken(object): - SENSITIVE_NAMES = ('password', 'token', 'secret', 'secret_key') +class JsonWebToken: + SENSITIVE_NAMES = ("password", "token", "secret", "secret_key") # Thanks to sentry SensitiveDataFilter - SENSITIVE_VALUES = re.compile(r'|'.join([ - # http://www.richardsramblings.com/regex/credit-card-numbers/ - r'\b(?:3[47]\d|(?:4\d|5[1-5]|65)\d{2}|6011)\d{12}\b', - # various private keys - r'-----BEGIN[A-Z ]+PRIVATE KEY-----.+-----END[A-Z ]+PRIVATE KEY-----', - # social security numbers (US) - r'^\b(?!(000|666|9))\d{3}-(?!00)\d{2}-(?!0000)\d{4}\b', - ]), re.DOTALL) - - def __init__(self, algorithms=None, private_headers=None): + SENSITIVE_VALUES = re.compile( + r"|".join( + [ + # http://www.richardsramblings.com/regex/credit-card-numbers/ + r"\b(?:3[47]\d|(?:4\d|5[1-5]|65)\d{2}|6011)\d{12}\b", + # various private keys + r"-----BEGIN[A-Z ]+PRIVATE KEY-----.+-----END[A-Z ]+PRIVATE KEY-----", + # social security numbers (US) + r"^\b(?!(000|666|9))\d{3}-(?!00)\d{2}-(?!0000)\d{4}\b", + ] + ), + re.DOTALL, + ) + + def __init__(self, algorithms, private_headers=None): self._jws = JsonWebSignature(algorithms, private_headers=private_headers) self._jwe = JsonWebEncryption(algorithms, private_headers=private_headers) @@ -37,7 +47,7 @@ def check_sensitive_data(self, payload): # check claims values v = payload[k] - if isinstance(v, text_types) and self.SENSITIVE_VALUES.search(v): + if isinstance(v, str) and self.SENSITIVE_VALUES.search(v): raise InsecureClaimError(k) def encode(self, header, payload, key, check=True): @@ -49,9 +59,9 @@ def encode(self, header, payload, key, check=True): :param check: check if sensitive data in payload :return: bytes """ - header['typ'] = 'JWT' + header.setdefault("typ", "JWT") - for k in ['exp', 'iat', 'nbf']: + for k in ["exp", "iat", "nbf"]: # convert datetime into timestamp claim = payload.get(k) if isinstance(claim, datetime.datetime): @@ -60,18 +70,15 @@ def encode(self, header, payload, key, check=True): if check: self.check_sensitive_data(payload) - if isinstance(key, KeySet): - key = key.find_by_kid(header.get('kid')) - + key = find_encode_key(key, header) text = to_bytes(json_dumps(payload)) - if 'enc' in header: + if "enc" in header: return self._jwe.serialize_compact(header, text, key) else: return self._jws.serialize_compact(header, text, key) - def decode(self, s, key, claims_cls=None, - claims_options=None, claims_params=None): - """Decode the JWS with the given key. This is similar with + def decode(self, s, key, claims_cls=None, claims_options=None, claims_params=None): + """Decode the JWT with the given key. This is similar with :meth:`verify`, except that it will raise BadSignatureError when signature doesn't match. @@ -86,22 +93,22 @@ def decode(self, s, key, claims_cls=None, if claims_cls is None: claims_cls = JWTClaims - if isinstance(key, KeySet): - def load_key(header, payload): - return key.find_by_kid(header.get('kid')) - else: + if callable(key): load_key = key + else: + load_key = create_load_key(prepare_raw_key(key)) s = to_bytes(s) - dot_count = s.count(b'.') + dot_count = s.count(b".") if dot_count == 2: data = self._jws.deserialize_compact(s, load_key, decode_payload) elif dot_count == 4: data = self._jwe.deserialize_compact(s, load_key, decode_payload) else: - raise DecodeError('Invalid input segments length') + raise DecodeError("Invalid input segments length") return claims_cls( - data['payload'], data['header'], + data["payload"], + data["header"], options=claims_options, params=claims_params, ) @@ -110,8 +117,75 @@ def load_key(header, payload): def decode_payload(bytes_payload): try: payload = json_loads(to_unicode(bytes_payload)) - except ValueError: - raise DecodeError('Invalid payload value') + except ValueError as exc: + raise DecodeError("Invalid payload value") from exc if not isinstance(payload, dict): - raise DecodeError('Invalid payload type') + raise DecodeError("Invalid payload type") return payload + + +def prepare_raw_key(raw): + if isinstance(raw, KeySet): + return raw + + if isinstance(raw, str) and raw.startswith("{") and raw.endswith("}"): + raw = json_loads(raw) + elif isinstance(raw, (tuple, list)): + raw = {"keys": raw} + return raw + + +def find_encode_key(key, header): + if isinstance(key, KeySet): + kid = header.get("kid") + if kid: + return key.find_by_kid(kid) + + rv = random.choice(key.keys) + # use side effect to add kid value into header + header["kid"] = rv.kid + return rv + + if isinstance(key, dict) and "keys" in key: + keys = key["keys"] + kid = header.get("kid") + for k in keys: + if k.get("kid") == kid: + return k + + if not kid: + rv = random.choice(keys) + header["kid"] = rv["kid"] + return rv + raise ValueError("Invalid JSON Web Key Set") + + # append kid into header + if isinstance(key, dict) and "kid" in key: + header["kid"] = key["kid"] + elif isinstance(key, Key) and key.kid: + header["kid"] = key.kid + return key + + +def create_load_key(key): + def load_key(header, payload): + if isinstance(key, KeySet): + return key.find_by_kid(header.get("kid")) + + if isinstance(key, dict) and "keys" in key: + keys = key["keys"] + kid = header.get("kid") + + if kid is not None: + # look for the requested key + for k in keys: + if k.get("kid") == kid: + return k + else: + # use the only key + if len(keys) == 1: + return keys[0] + raise ValueError("Invalid JSON Web Key Set") + return key + + return load_key diff --git a/authlib/jose/rfc8037/__init__.py b/authlib/jose/rfc8037/__init__.py index 46a6831ed..2c13c3741 100644 --- a/authlib/jose/rfc8037/__init__.py +++ b/authlib/jose/rfc8037/__init__.py @@ -1,5 +1,4 @@ +from .jws_eddsa import register_jws_rfc8037 from .okp_key import OKPKey -from ._jws_cryptography import register_jws_rfc8037 - -__all__ = ['register_jws_rfc8037', 'OKPKey'] +__all__ = ["register_jws_rfc8037", "OKPKey"] diff --git a/authlib/jose/rfc8037/_jws_cryptography.py b/authlib/jose/rfc8037/_jws_cryptography.py deleted file mode 100644 index 86989fd5c..000000000 --- a/authlib/jose/rfc8037/_jws_cryptography.py +++ /dev/null @@ -1,32 +0,0 @@ -from cryptography.exceptions import InvalidSignature -from cryptography.hazmat.primitives.asymmetric.ed25519 import ( - Ed25519PublicKey, Ed25519PrivateKey -) -from authlib.jose.rfc7515 import JWSAlgorithm, JsonWebSignature -from .okp_key import OKPKey - - -class EdDSAAlgorithm(JWSAlgorithm): - name = 'EdDSA' - description = 'Edwards-curve Digital Signature Algorithm for JWS' - private_key_cls = Ed25519PrivateKey - public_key_cls = Ed25519PublicKey - - def prepare_key(self, raw_data): - return OKPKey.import_key(raw_data) - - def sign(self, msg, key): - op_key = key.get_op_key('sign') - return op_key.sign(msg) - - def verify(self, msg, sig, key): - op_key = key.get_op_key('verify') - try: - op_key.verify(sig, msg) - return True - except InvalidSignature: - return False - - -def register_jws_rfc8037(): - JsonWebSignature.register_algorithm(EdDSAAlgorithm()) diff --git a/authlib/jose/rfc8037/jws_eddsa.py b/authlib/jose/rfc8037/jws_eddsa.py new file mode 100644 index 000000000..e8ab16cc6 --- /dev/null +++ b/authlib/jose/rfc8037/jws_eddsa.py @@ -0,0 +1,28 @@ +from cryptography.exceptions import InvalidSignature + +from ..rfc7515 import JWSAlgorithm +from .okp_key import OKPKey + + +class EdDSAAlgorithm(JWSAlgorithm): + name = "EdDSA" + description = "Edwards-curve Digital Signature Algorithm for JWS" + + def prepare_key(self, raw_data): + return OKPKey.import_key(raw_data) + + def sign(self, msg, key): + op_key = key.get_op_key("sign") + return op_key.sign(msg) + + def verify(self, msg, sig, key): + op_key = key.get_op_key("verify") + try: + op_key.verify(sig, msg) + return True + except InvalidSignature: + return False + + +def register_jws_rfc8037(cls): + cls.register_algorithm(EdDSAAlgorithm()) diff --git a/authlib/jose/rfc8037/okp_key.py b/authlib/jose/rfc8037/okp_key.py index d8438b3bf..034b40d14 100644 --- a/authlib/jose/rfc8037/okp_key.py +++ b/authlib/jose/rfc8037/okp_key.py @@ -1,128 +1,97 @@ -from cryptography.hazmat.primitives.asymmetric.ed25519 import ( - Ed25519PublicKey, Ed25519PrivateKey -) -from cryptography.hazmat.primitives.asymmetric.ed448 import ( - Ed448PublicKey, Ed448PrivateKey -) -from cryptography.hazmat.primitives.asymmetric.x25519 import ( - X25519PublicKey, X25519PrivateKey -) -from cryptography.hazmat.primitives.asymmetric.x448 import ( - X448PublicKey, X448PrivateKey -) -from cryptography.hazmat.primitives.serialization import ( - Encoding, PublicFormat, PrivateFormat, NoEncryption -) -from authlib.common.encoding import ( - to_unicode, to_bytes, - urlsafe_b64decode, urlsafe_b64encode, -) -from authlib.jose.rfc7517 import Key -from ..rfc7518 import import_key, export_key - +from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PrivateKey +from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PublicKey +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey +from cryptography.hazmat.primitives.asymmetric.x448 import X448PrivateKey +from cryptography.hazmat.primitives.asymmetric.x448 import X448PublicKey +from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey +from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PublicKey +from cryptography.hazmat.primitives.serialization import Encoding +from cryptography.hazmat.primitives.serialization import NoEncryption +from cryptography.hazmat.primitives.serialization import PrivateFormat +from cryptography.hazmat.primitives.serialization import PublicFormat + +from authlib.common.encoding import to_bytes +from authlib.common.encoding import to_unicode +from authlib.common.encoding import urlsafe_b64decode +from authlib.common.encoding import urlsafe_b64encode + +from ..rfc7517 import AsymmetricKey PUBLIC_KEYS_MAP = { - 'Ed25519': Ed25519PublicKey, - 'Ed448': Ed448PublicKey, - 'X25519': X25519PublicKey, - 'X448': X448PublicKey, + "Ed25519": Ed25519PublicKey, + "Ed448": Ed448PublicKey, + "X25519": X25519PublicKey, + "X448": X448PublicKey, } PRIVATE_KEYS_MAP = { - 'Ed25519': Ed25519PrivateKey, - 'Ed448': Ed448PrivateKey, - 'X25519': X25519PrivateKey, - 'X448': X448PrivateKey, + "Ed25519": Ed25519PrivateKey, + "Ed448": Ed448PrivateKey, + "X25519": X25519PrivateKey, + "X448": X448PrivateKey, } -PUBLIC_KEY_TUPLE = tuple(PUBLIC_KEYS_MAP.values()) -PRIVATE_KEY_TUPLE = tuple(PRIVATE_KEYS_MAP.values()) -class OKPKey(Key): +class OKPKey(AsymmetricKey): """Key class of the ``OKP`` key type.""" - kty = 'OKP' - REQUIRED_JSON_FIELDS = ['crv', 'x'] - RAW_KEY_CLS = ( - Ed25519PublicKey, Ed25519PrivateKey, - Ed448PublicKey, Ed448PrivateKey, - X25519PublicKey, X25519PrivateKey, - X448PublicKey, X448PrivateKey, - ) - - def as_pem(self, is_private=False, password=None): - """Export key into PEM format bytes. - - :param is_private: export private key or public key - :param password: encrypt private key with password - :return: bytes - """ - return export_key(self, is_private=is_private, password=password) + kty = "OKP" + REQUIRED_JSON_FIELDS = ["crv", "x"] + PUBLIC_KEY_FIELDS = REQUIRED_JSON_FIELDS + PRIVATE_KEY_FIELDS = ["crv", "d"] + PUBLIC_KEY_CLS = tuple(PUBLIC_KEYS_MAP.values()) + PRIVATE_KEY_CLS = tuple(PRIVATE_KEYS_MAP.values()) + SSH_PUBLIC_PREFIX = b"ssh-ed25519" def exchange_shared_key(self, pubkey): - # used in ECDHAlgorithm - if isinstance(self.raw_key, (X25519PrivateKey, X448PrivateKey)): - return self.raw_key.exchange(pubkey) - raise ValueError('Invalid key for exchanging shared key') - - @property - def curve_key_size(self): - raise NotImplementedError() + # used in ECDHESAlgorithm + private_key = self.get_private_key() + if private_key and isinstance(private_key, (X25519PrivateKey, X448PrivateKey)): + return private_key.exchange(pubkey) + raise ValueError("Invalid key for exchanging shared key") @staticmethod def get_key_curve(key): if isinstance(key, (Ed25519PublicKey, Ed25519PrivateKey)): - return 'Ed25519' + return "Ed25519" elif isinstance(key, (Ed448PublicKey, Ed448PrivateKey)): - return 'Ed448' + return "Ed448" elif isinstance(key, (X25519PublicKey, X25519PrivateKey)): - return 'X25519' + return "X25519" elif isinstance(key, (X448PublicKey, X448PrivateKey)): - return 'X448' + return "X448" - @staticmethod - def loads_private_key(obj): - crv_key = PRIVATE_KEYS_MAP[obj['crv']] - d_bytes = urlsafe_b64decode(to_bytes(obj['d'])) + def load_private_key(self): + crv_key = PRIVATE_KEYS_MAP[self._dict_data["crv"]] + d_bytes = urlsafe_b64decode(to_bytes(self._dict_data["d"])) return crv_key.from_private_bytes(d_bytes) - @staticmethod - def loads_public_key(obj): - crv_key = PUBLIC_KEYS_MAP[obj['crv']] - x_bytes = urlsafe_b64decode(to_bytes(obj['x'])) + def load_public_key(self): + crv_key = PUBLIC_KEYS_MAP[self._dict_data["crv"]] + x_bytes = urlsafe_b64decode(to_bytes(self._dict_data["x"])) return crv_key.from_public_bytes(x_bytes) - @staticmethod - def dumps_private_key(raw_key): - obj = OKPKey.dumps_public_key(raw_key.public_key()) - d_bytes = raw_key.private_bytes( - Encoding.Raw, - PrivateFormat.Raw, - NoEncryption() + def dumps_private_key(self): + obj = self.dumps_public_key(self.private_key.public_key()) + d_bytes = self.private_key.private_bytes( + Encoding.Raw, PrivateFormat.Raw, NoEncryption() ) - obj['d'] = to_unicode(urlsafe_b64encode(d_bytes)) + obj["d"] = to_unicode(urlsafe_b64encode(d_bytes)) return obj - @staticmethod - def dumps_public_key(raw_key): - x_bytes = raw_key.public_bytes(Encoding.Raw, PublicFormat.Raw) + def dumps_public_key(self, public_key=None): + if public_key is None: + public_key = self.public_key + x_bytes = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw) return { - 'crv': OKPKey.get_key_curve(raw_key), - 'x': to_unicode(urlsafe_b64encode(x_bytes)), + "crv": self.get_key_curve(public_key), + "x": to_unicode(urlsafe_b64encode(x_bytes)), } @classmethod - def import_key(cls, raw, options=None): - """Import a key from PEM or dict data.""" - return import_key( - cls, raw, - PUBLIC_KEY_TUPLE, PRIVATE_KEY_TUPLE, - b'ssh-ed25519', options - ) - - @classmethod - def generate_key(cls, crv='Ed25519', options=None, is_private=False): + def generate_key(cls, crv="Ed25519", options=None, is_private=False) -> "OKPKey": if crv not in PRIVATE_KEYS_MAP: - raise ValueError('Invalid crv value: "{}"'.format(crv)) + raise ValueError(f'Invalid crv value: "{crv}"') private_key_cls = PRIVATE_KEYS_MAP[crv] raw_key = private_key_cls.generate() if not is_private: diff --git a/authlib/jose/util.py b/authlib/jose/util.py index 08414cb96..848b95018 100644 --- a/authlib/jose/util.py +++ b/authlib/jose/util.py @@ -1,23 +1,46 @@ import binascii -from authlib.common.encoding import urlsafe_b64decode, json_loads + +from authlib.common.encoding import json_loads +from authlib.common.encoding import to_unicode +from authlib.common.encoding import urlsafe_b64decode +from authlib.jose.errors import DecodeError def extract_header(header_segment, error_cls): - header_data = extract_segment(header_segment, error_cls, 'header') + if len(header_segment) > 256000: + raise ValueError("Value of header is too long") + + header_data = extract_segment(header_segment, error_cls, "header") try: - header = json_loads(header_data.decode('utf-8')) + header = json_loads(header_data.decode("utf-8")) except ValueError as e: - raise error_cls('Invalid header string: {}'.format(e)) + raise error_cls(f"Invalid header string: {e}") from e if not isinstance(header, dict): - raise error_cls('Header must be a json object') + raise error_cls("Header must be a json object") return header -def extract_segment(segment, error_cls, name='payload'): +def extract_segment(segment, error_cls, name="payload"): + if len(segment) > 256000: + raise ValueError(f"Value of {name} is too long") + try: return urlsafe_b64decode(segment) - except (TypeError, binascii.Error): - msg = 'Invalid {} padding'.format(name) - raise error_cls(msg) + except (TypeError, binascii.Error) as exc: + msg = f"Invalid {name} padding" + raise error_cls(msg) from exc + + +def ensure_dict(s, structure_name): + if not isinstance(s, dict): + try: + s = json_loads(to_unicode(s)) + except (ValueError, TypeError) as exc: + raise DecodeError(f"Invalid {structure_name}") from exc + + if not isinstance(s, dict): + raise DecodeError(f"Invalid {structure_name}") + + return s diff --git a/authlib/oauth1/__init__.py b/authlib/oauth1/__init__.py index af1ba0792..203b73e4a 100644 --- a/authlib/oauth1/__init__.py +++ b/authlib/oauth1/__init__.py @@ -1,36 +1,31 @@ -# coding: utf-8 - -from .rfc5849 import ( - OAuth1Request, - ClientAuth, - SIGNATURE_HMAC_SHA1, - SIGNATURE_RSA_SHA1, - SIGNATURE_PLAINTEXT, - SIGNATURE_TYPE_HEADER, - SIGNATURE_TYPE_QUERY, - SIGNATURE_TYPE_BODY, - ClientMixin, - TemporaryCredentialMixin, - TokenCredentialMixin, - TemporaryCredential, - AuthorizationServer, - ResourceProtector, -) +from .rfc5849 import SIGNATURE_HMAC_SHA1 +from .rfc5849 import SIGNATURE_PLAINTEXT +from .rfc5849 import SIGNATURE_RSA_SHA1 +from .rfc5849 import SIGNATURE_TYPE_BODY +from .rfc5849 import SIGNATURE_TYPE_HEADER +from .rfc5849 import SIGNATURE_TYPE_QUERY +from .rfc5849 import AuthorizationServer +from .rfc5849 import ClientAuth +from .rfc5849 import ClientMixin +from .rfc5849 import OAuth1Request +from .rfc5849 import ResourceProtector +from .rfc5849 import TemporaryCredential +from .rfc5849 import TemporaryCredentialMixin +from .rfc5849 import TokenCredentialMixin __all__ = [ - 'OAuth1Request', - 'ClientAuth', - 'SIGNATURE_HMAC_SHA1', - 'SIGNATURE_RSA_SHA1', - 'SIGNATURE_PLAINTEXT', - 'SIGNATURE_TYPE_HEADER', - 'SIGNATURE_TYPE_QUERY', - 'SIGNATURE_TYPE_BODY', - - 'ClientMixin', - 'TemporaryCredentialMixin', - 'TokenCredentialMixin', - 'TemporaryCredential', - 'AuthorizationServer', - 'ResourceProtector', + "OAuth1Request", + "ClientAuth", + "SIGNATURE_HMAC_SHA1", + "SIGNATURE_RSA_SHA1", + "SIGNATURE_PLAINTEXT", + "SIGNATURE_TYPE_HEADER", + "SIGNATURE_TYPE_QUERY", + "SIGNATURE_TYPE_BODY", + "ClientMixin", + "TemporaryCredentialMixin", + "TokenCredentialMixin", + "TemporaryCredential", + "AuthorizationServer", + "ResourceProtector", ] diff --git a/authlib/oauth1/client.py b/authlib/oauth1/client.py index 7715711b5..ad523da72 100644 --- a/authlib/oauth1/client.py +++ b/authlib/oauth1/client.py @@ -1,39 +1,48 @@ -# -*- coding: utf-8 -*- -from authlib.common.urls import ( - url_decode, - add_params_to_uri, - urlparse, -) from authlib.common.encoding import json_loads -from .rfc5849 import ( - SIGNATURE_HMAC_SHA1, - SIGNATURE_TYPE_HEADER, - ClientAuth, -) +from authlib.common.urls import add_params_to_uri +from authlib.common.urls import url_decode +from authlib.common.urls import urlparse +from .rfc5849 import SIGNATURE_HMAC_SHA1 +from .rfc5849 import SIGNATURE_TYPE_HEADER +from .rfc5849 import ClientAuth -class OAuth1Client(object): + +class OAuth1Client: auth_class = ClientAuth - def __init__(self, session, client_id, client_secret=None, - token=None, token_secret=None, - redirect_uri=None, rsa_key=None, verifier=None, - signature_method=SIGNATURE_HMAC_SHA1, - signature_type=SIGNATURE_TYPE_HEADER, - force_include_body=False, **kwargs): + def __init__( + self, + session, + client_id, + client_secret=None, + token=None, + token_secret=None, + redirect_uri=None, + rsa_key=None, + verifier=None, + signature_method=SIGNATURE_HMAC_SHA1, + signature_type=SIGNATURE_TYPE_HEADER, + force_include_body=False, + realm=None, + **kwargs, + ): if not client_id: raise ValueError('Missing "client_id"') self.session = session self.auth = self.auth_class( - client_id, client_secret=client_secret, - token=token, token_secret=token_secret, + client_id, + client_secret=client_secret, + token=token, + token_secret=token_secret, redirect_uri=redirect_uri, signature_method=signature_method, signature_type=signature_type, rsa_key=rsa_key, verifier=verifier, - force_include_body=force_include_body + realm=realm, + force_include_body=force_include_body, ) self._kwargs = kwargs @@ -50,7 +59,7 @@ def token(self): return dict( oauth_token=self.auth.token, oauth_token_secret=self.auth.token_secret, - oauth_verifier=self.auth.verifier + oauth_verifier=self.auth.verifier, ) @token.setter @@ -63,15 +72,15 @@ def token(self, token): self.auth.token = None self.auth.token_secret = None self.auth.verifier = None - elif 'oauth_token' in token: - self.auth.token = token['oauth_token'] - if 'oauth_token_secret' in token: - self.auth.token_secret = token['oauth_token_secret'] - if 'oauth_verifier' in token: - self.auth.verifier = token['oauth_verifier'] + elif "oauth_token" in token: + self.auth.token = token["oauth_token"] + if "oauth_token_secret" in token: + self.auth.token_secret = token["oauth_token_secret"] + if "oauth_verifier" in token: + self.auth.verifier = token["oauth_verifier"] else: - message = 'oauth_token is missing: {!r}'.format(token) - self.handle_error('missing_token', message) + message = f"oauth_token is missing: {token!r}" + self.handle_error("missing_token", message) def create_authorization_url(self, url, request_token=None, **kwargs): """Create an authorization URL by appending request_token and optional @@ -87,15 +96,12 @@ def create_authorization_url(self, url, request_token=None, **kwargs): :param kwargs: Optional parameters to append to the URL. :returns: The authorization URL with new parameters embedded. """ - kwargs['oauth_token'] = request_token or self.auth.token + kwargs["oauth_token"] = request_token or self.auth.token if self.auth.redirect_uri: - kwargs['oauth_callback'] = self.auth.redirect_uri - - self.auth.redirect_uri = None - self.auth.realm = None + kwargs["oauth_callback"] = self.auth.redirect_uri return add_params_to_uri(url, kwargs.items()) - def fetch_request_token(self, url, realm=None, **kwargs): + def fetch_request_token(self, url, **kwargs): """Method for fetching an access token from the token endpoint. This is the first step in the OAuth 1 workflow. A request token is @@ -104,23 +110,9 @@ def fetch_request_token(self, url, realm=None, **kwargs): to be used to construct an authorization url. :param url: Request Token endpoint. - :param realm: A string/list/tuple of realm for Authorization header. :param kwargs: Extra parameters to include for fetching token. :return: A Request Token dict. - - Note, ``realm`` can also be configured when session created:: - - session = OAuth1Session(client_id, client_secret, ..., realm='') """ - if realm is None: - realm = self._kwargs.get('realm', None) - if realm: - if isinstance(realm, (tuple, list)): - realm = ' '.join(realm) - self.auth.realm = realm - else: - self.auth.realm = None - return self._fetch_token(url, **kwargs) def fetch_access_token(self, url, verifier=None, **kwargs): @@ -138,7 +130,7 @@ def fetch_access_token(self, url, verifier=None, **kwargs): if verifier: self.auth.verifier = verifier if not self.auth.verifier: - self.handle_error('missing_verifier', 'Missing "verifier" value') + self.handle_error("missing_verifier", 'Missing "verifier" value') return self._fetch_token(url, **kwargs) def parse_authorization_response(self, url): @@ -163,14 +155,13 @@ def _fetch_token(self, url, **kwargs): def parse_response_token(self, status_code, text): if status_code >= 400: message = ( - "Token request failed with code {}, " - "response was '{}'." - ).format(status_code, text) - self.handle_error('fetch_token_denied', message) + f"Token request failed with code {status_code}, response was '{text}'." + ) + self.handle_error("fetch_token_denied", message) try: text = text.strip() - if text.startswith('{'): + if text.startswith("{"): token = json_loads(text) else: token = dict(url_decode(text)) @@ -179,11 +170,17 @@ def parse_response_token(self, status_code, text): "Unable to decode token from token response. " "This is commonly caused by an unsuccessful request where" " a non urlencoded error message is returned. " - "The decoding error was {}" - ).format(e) - raise ValueError(error) + f"The decoding error was {e}" + ) + raise ValueError(error) from e return token @staticmethod def handle_error(error_type, error_description): - raise ValueError('{}: {}'.format(error_type, error_description)) + raise ValueError(f"{error_type}: {error_description}") + + def __del__(self): + try: + del self.session + except AttributeError: + pass diff --git a/authlib/oauth1/rfc5849/__init__.py b/authlib/oauth1/rfc5849/__init__.py index 1f029fbbe..bb7fad8c2 100644 --- a/authlib/oauth1/rfc5849/__init__.py +++ b/authlib/oauth1/rfc5849/__init__.py @@ -1,45 +1,39 @@ -""" - authlib.oauth1.rfc5849 - ~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth1.rfc5849. +~~~~~~~~~~~~~~~~~~~~~~ - This module represents a direct implementation of The OAuth 1.0 Protocol. +This module represents a direct implementation of The OAuth 1.0 Protocol. - https://tools.ietf.org/html/rfc5849 +https://tools.ietf.org/html/rfc5849 """ -from .wrapper import OAuth1Request -from .client_auth import ClientAuth -from .signature import ( - SIGNATURE_HMAC_SHA1, - SIGNATURE_RSA_SHA1, - SIGNATURE_PLAINTEXT, - SIGNATURE_TYPE_HEADER, - SIGNATURE_TYPE_QUERY, - SIGNATURE_TYPE_BODY, -) -from .models import ( - ClientMixin, - TemporaryCredentialMixin, - TokenCredentialMixin, - TemporaryCredential, -) from .authorization_server import AuthorizationServer +from .client_auth import ClientAuth +from .models import ClientMixin +from .models import TemporaryCredential +from .models import TemporaryCredentialMixin +from .models import TokenCredentialMixin from .resource_protector import ResourceProtector +from .signature import SIGNATURE_HMAC_SHA1 +from .signature import SIGNATURE_PLAINTEXT +from .signature import SIGNATURE_RSA_SHA1 +from .signature import SIGNATURE_TYPE_BODY +from .signature import SIGNATURE_TYPE_HEADER +from .signature import SIGNATURE_TYPE_QUERY +from .wrapper import OAuth1Request __all__ = [ - 'OAuth1Request', - 'ClientAuth', - 'SIGNATURE_HMAC_SHA1', - 'SIGNATURE_RSA_SHA1', - 'SIGNATURE_PLAINTEXT', - 'SIGNATURE_TYPE_HEADER', - 'SIGNATURE_TYPE_QUERY', - 'SIGNATURE_TYPE_BODY', - - 'ClientMixin', - 'TemporaryCredentialMixin', - 'TokenCredentialMixin', - 'TemporaryCredential', - 'AuthorizationServer', - 'ResourceProtector', + "OAuth1Request", + "ClientAuth", + "SIGNATURE_HMAC_SHA1", + "SIGNATURE_RSA_SHA1", + "SIGNATURE_PLAINTEXT", + "SIGNATURE_TYPE_HEADER", + "SIGNATURE_TYPE_QUERY", + "SIGNATURE_TYPE_BODY", + "ClientMixin", + "TemporaryCredentialMixin", + "TokenCredentialMixin", + "TemporaryCredential", + "AuthorizationServer", + "ResourceProtector", ] diff --git a/authlib/oauth1/rfc5849/authorization_server.py b/authlib/oauth1/rfc5849/authorization_server.py index be9b985be..ddbf293b0 100644 --- a/authlib/oauth1/rfc5849/authorization_server.py +++ b/authlib/oauth1/rfc5849/authorization_server.py @@ -1,24 +1,24 @@ -from authlib.common.urls import is_valid_url, add_params_to_uri +from authlib.common.urls import add_params_to_uri +from authlib.common.urls import is_valid_url + from .base_server import BaseServer -from .errors import ( - OAuth1Error, - InvalidRequestError, - MissingRequiredParameterError, - InvalidClientError, - InvalidTokenError, - AccessDeniedError, - MethodNotAllowedError, -) +from .errors import AccessDeniedError +from .errors import InvalidClientError +from .errors import InvalidRequestError +from .errors import InvalidTokenError +from .errors import MethodNotAllowedError +from .errors import MissingRequiredParameterError +from .errors import OAuth1Error class AuthorizationServer(BaseServer): TOKEN_RESPONSE_HEADER = [ - ('Content-Type', 'application/x-www-form-urlencoded'), - ('Cache-Control', 'no-store'), - ('Pragma', 'no-cache'), + ("Content-Type", "application/x-www-form-urlencoded"), + ("Cache-Control", "no-store"), + ("Pragma", "no-cache"), ] - TEMPORARY_CREDENTIALS_METHOD = 'POST' + TEMPORARY_CREDENTIALS_METHOD = "POST" def _get_client(self, request): client = self.get_client_by_id(request.client_id) @@ -33,14 +33,11 @@ def handle_response(self, status_code, payload, headers): def handle_error_response(self, error): return self.handle_response( - error.status_code, - error.get_body(), - error.get_headers() + error.status_code, error.get_body(), error.get_headers() ) def validate_temporary_credentials_request(self, request): """Validate HTTP request for temporary credentials.""" - # The client obtains a set of temporary credentials from the server by # making an authenticated (Section 3) HTTP "POST" request to the # Temporary Credential Request endpoint (unless the server advertises @@ -50,16 +47,16 @@ def validate_temporary_credentials_request(self, request): # REQUIRED parameter if not request.client_id: - raise MissingRequiredParameterError('oauth_consumer_key') + raise MissingRequiredParameterError("oauth_consumer_key") # REQUIRED parameter oauth_callback = request.redirect_uri if not request.redirect_uri: - raise MissingRequiredParameterError('oauth_callback') + raise MissingRequiredParameterError("oauth_callback") # An absolute URI or # other means (the parameter value MUST be set to "oob" - if oauth_callback != 'oob' and not is_valid_url(oauth_callback): + if oauth_callback != "oob" and not is_valid_url(oauth_callback): raise InvalidRequestError('Invalid "oauth_callback" value') client = self._get_client(request) @@ -109,16 +106,16 @@ def create_temporary_credentials_response(self, request=None): credential = self.create_temporary_credential(request) payload = [ - ('oauth_token', credential.get_oauth_token()), - ('oauth_token_secret', credential.get_oauth_token_secret()), - ('oauth_callback_confirmed', True) + ("oauth_token", credential.get_oauth_token()), + ("oauth_token_secret", credential.get_oauth_token_secret()), + ("oauth_callback_confirmed", True), ] return self.handle_response(200, payload, self.TOKEN_RESPONSE_HEADER) def validate_authorization_request(self, request): """Validate the request for resource owner authorization.""" if not request.token: - raise MissingRequiredParameterError('oauth_token') + raise MissingRequiredParameterError("oauth_token") credential = self.get_temporary_credential(request) if not credential: @@ -156,7 +153,7 @@ def create_authorization_response(self, request, grant_user=None): temporary_credentials = request.credential redirect_uri = temporary_credentials.get_redirect_uri() - if not redirect_uri or redirect_uri == 'oob': + if not redirect_uri or redirect_uri == "oob": client_id = temporary_credentials.get_client_id() client = self.get_client_by_id(client_id) redirect_uri = client.get_default_redirect_uri() @@ -164,38 +161,34 @@ def create_authorization_response(self, request, grant_user=None): if grant_user is None: error = AccessDeniedError() location = add_params_to_uri(redirect_uri, error.get_body()) - return self.handle_response(302, '', [('Location', location)]) + return self.handle_response(302, "", [("Location", location)]) request.user = grant_user verifier = self.create_authorization_verifier(request) - params = [ - ('oauth_token', request.token), - ('oauth_verifier', verifier) - ] + params = [("oauth_token", request.token), ("oauth_verifier", verifier)] location = add_params_to_uri(redirect_uri, params) - return self.handle_response(302, '', [('Location', location)]) + return self.handle_response(302, "", [("Location", location)]) def validate_token_request(self, request): """Validate request for issuing token.""" - if not request.client_id: - raise MissingRequiredParameterError('oauth_consumer_key') + raise MissingRequiredParameterError("oauth_consumer_key") client = self._get_client(request) if not client: raise InvalidClientError() if not request.token: - raise MissingRequiredParameterError('oauth_token') + raise MissingRequiredParameterError("oauth_token") token = self.get_temporary_credential(request) if not token: raise InvalidTokenError() - verifier = request.oauth_params.get('oauth_verifier') + verifier = request.oauth_params.get("oauth_verifier") if not verifier: - raise MissingRequiredParameterError('oauth_verifier') + raise MissingRequiredParameterError("oauth_verifier") if not token.check_verifier(verifier): raise InvalidRequestError('Invalid "oauth_verifier"') @@ -252,8 +245,8 @@ def create_token_response(self, request): credential = self.create_token_credential(request) payload = [ - ('oauth_token', credential.get_oauth_token()), - ('oauth_token_secret', credential.get_oauth_token_secret()), + ("oauth_token", credential.get_oauth_token()), + ("oauth_token_secret", credential.get_oauth_token_secret()), ] self.delete_temporary_credential(request) return self.handle_response(200, payload, self.TOKEN_RESPONSE_HEADER) @@ -287,7 +280,7 @@ def get_temporary_credential(self, request): ``TemporaryCredentialMixin``:: def get_temporary_credential(self, request): - key = 'a-key-prefix:{}'.format(request.token) + key = "a-key-prefix:{}".format(request.token) data = cache.get(key) # TemporaryCredential shares methods from TemporaryCredentialMixin return TemporaryCredential(data) @@ -302,7 +295,7 @@ def delete_temporary_credential(self, request): if temporary credential is saved in cache:: def delete_temporary_credential(self, request): - key = 'a-key-prefix:{}'.format(request.token) + key = "a-key-prefix:{}".format(request.token) cache.delete(key) :param request: OAuth1Request instance @@ -317,7 +310,7 @@ def create_authorization_verifier(self, request): verifier = generate_token(36) temporary_credential = request.credential - user_id = request.user.get_user_id() + user_id = request.user.id temporary_credential.user_id = user_id temporary_credential.oauth_verifier = verifier @@ -345,7 +338,7 @@ def create_token_credential(self, request): oauth_token=oauth_token, oauth_token_secret=oauth_token_secret, client_id=temporary_credential.get_client_id(), - user_id=temporary_credential.get_user_id() + user_id=temporary_credential.get_user_id(), ) # if the credential has a save method token_credential.save() diff --git a/authlib/oauth1/rfc5849/base_server.py b/authlib/oauth1/rfc5849/base_server.py index 46898bb20..68bb426be 100644 --- a/authlib/oauth1/rfc5849/base_server.py +++ b/authlib/oauth1/rfc5849/base_server.py @@ -1,24 +1,19 @@ import time -from .signature import ( - SIGNATURE_HMAC_SHA1, - SIGNATURE_PLAINTEXT, - SIGNATURE_RSA_SHA1, -) -from .signature import ( - verify_hmac_sha1, - verify_plaintext, - verify_rsa_sha1, -) -from .errors import ( - InvalidRequestError, - MissingRequiredParameterError, - UnsupportedSignatureMethodError, - InvalidNonceError, - InvalidSignatureError, -) - - -class BaseServer(object): + +from .errors import InvalidNonceError +from .errors import InvalidRequestError +from .errors import InvalidSignatureError +from .errors import MissingRequiredParameterError +from .errors import UnsupportedSignatureMethodError +from .signature import SIGNATURE_HMAC_SHA1 +from .signature import SIGNATURE_PLAINTEXT +from .signature import SIGNATURE_RSA_SHA1 +from .signature import verify_hmac_sha1 +from .signature import verify_plaintext +from .signature import verify_rsa_sha1 + + +class BaseServer: SIGNATURE_METHODS = { SIGNATURE_HMAC_SHA1: verify_hmac_sha1, SIGNATURE_RSA_SHA1: verify_rsa_sha1, @@ -40,7 +35,8 @@ def verify_custom_method(request): # verify this request, return True or False return True - Server.register_signature_method('custom-name', verify_custom_method) + + Server.register_signature_method("custom-name", verify_custom_method) """ cls.SIGNATURE_METHODS[name] = verify @@ -49,8 +45,8 @@ def validate_timestamp_and_nonce(self, request): :param request: OAuth1Request instance """ - timestamp = request.oauth_params.get('oauth_timestamp') - nonce = request.oauth_params.get('oauth_nonce') + timestamp = request.oauth_params.get("oauth_timestamp") + nonce = request.oauth_params.get("oauth_nonce") if request.signature_method == SIGNATURE_PLAINTEXT: # The parameters MAY be omitted when using the "PLAINTEXT" @@ -59,7 +55,7 @@ def validate_timestamp_and_nonce(self, request): return if not timestamp: - raise MissingRequiredParameterError('oauth_timestamp') + raise MissingRequiredParameterError("oauth_timestamp") try: # The timestamp value MUST be a positive integer @@ -69,11 +65,11 @@ def validate_timestamp_and_nonce(self, request): if self.EXPIRY_TIME and time.time() - timestamp > self.EXPIRY_TIME: raise InvalidRequestError('Invalid "oauth_timestamp" value') - except (ValueError, TypeError): - raise InvalidRequestError('Invalid "oauth_timestamp" value') + except (ValueError, TypeError) as exc: + raise InvalidRequestError('Invalid "oauth_timestamp" value') from exc if not nonce: - raise MissingRequiredParameterError('oauth_nonce') + raise MissingRequiredParameterError("oauth_nonce") if self.exists_nonce(nonce, request): raise InvalidNonceError() @@ -85,13 +81,13 @@ def validate_oauth_signature(self, request): """ method = request.signature_method if not method: - raise MissingRequiredParameterError('oauth_signature_method') + raise MissingRequiredParameterError("oauth_signature_method") if method not in self.SUPPORTED_SIGNATURE_METHODS: raise UnsupportedSignatureMethodError() if not request.signature: - raise MissingRequiredParameterError('oauth_signature') + raise MissingRequiredParameterError("oauth_signature") verify = self.SIGNATURE_METHODS.get(method) if not verify: diff --git a/authlib/oauth1/rfc5849/client_auth.py b/authlib/oauth1/rfc5849/client_auth.py index 504a3523e..d1d73bb83 100644 --- a/authlib/oauth1/rfc5849/client_auth.py +++ b/authlib/oauth1/rfc5849/client_auth.py @@ -1,33 +1,30 @@ +import base64 +import hashlib import time + +from authlib.common.encoding import to_native from authlib.common.security import generate_token from authlib.common.urls import extract_params -from authlib.common.encoding import to_native + +from .parameters import prepare_form_encoded_body +from .parameters import prepare_headers +from .parameters import prepare_request_uri_query +from .signature import SIGNATURE_HMAC_SHA1 +from .signature import SIGNATURE_PLAINTEXT +from .signature import SIGNATURE_RSA_SHA1 +from .signature import SIGNATURE_TYPE_BODY +from .signature import SIGNATURE_TYPE_HEADER +from .signature import SIGNATURE_TYPE_QUERY +from .signature import sign_hmac_sha1 +from .signature import sign_plaintext +from .signature import sign_rsa_sha1 from .wrapper import OAuth1Request -from .signature import ( - SIGNATURE_HMAC_SHA1, - SIGNATURE_PLAINTEXT, - SIGNATURE_RSA_SHA1, - SIGNATURE_TYPE_HEADER, - SIGNATURE_TYPE_BODY, - SIGNATURE_TYPE_QUERY, -) -from .signature import ( - sign_hmac_sha1, - sign_rsa_sha1, - sign_plaintext -) -from .parameters import ( - prepare_form_encoded_body, - prepare_headers, - prepare_request_uri_query, -) - - -CONTENT_TYPE_FORM_URLENCODED = 'application/x-www-form-urlencoded' -CONTENT_TYPE_MULTI_PART = 'multipart/form-data' - - -class ClientAuth(object): + +CONTENT_TYPE_FORM_URLENCODED = "application/x-www-form-urlencoded" +CONTENT_TYPE_MULTI_PART = "multipart/form-data" + + +class ClientAuth: SIGNATURE_METHODS = { SIGNATURE_HMAC_SHA1: sign_hmac_sha1, SIGNATURE_RSA_SHA1: sign_rsa_sha1, @@ -45,18 +42,27 @@ def register_signature_method(cls, name, sign): def custom_sign_method(client, request): # client is the instance of Client. - return 'your-signed-string' + return "your-signed-string" + - Client.register_signature_method('custom-name', custom_sign_method) + Client.register_signature_method("custom-name", custom_sign_method) """ cls.SIGNATURE_METHODS[name] = sign - def __init__(self, client_id, client_secret=None, - token=None, token_secret=None, - redirect_uri=None, rsa_key=None, verifier=None, - signature_method=SIGNATURE_HMAC_SHA1, - signature_type=SIGNATURE_TYPE_HEADER, - realm=None, force_include_body=False): + def __init__( + self, + client_id, + client_secret=None, + token=None, + token_secret=None, + redirect_uri=None, + rsa_key=None, + verifier=None, + signature_method=SIGNATURE_HMAC_SHA1, + signature_type=SIGNATURE_TYPE_HEADER, + realm=None, + force_include_body=False, + ): self.client_id = client_id self.client_secret = client_secret self.token = token @@ -70,7 +76,7 @@ def __init__(self, client_id, client_secret=None, self.force_include_body = force_include_body def get_oauth_signature(self, method, uri, headers, body): - """Get an OAuth signature to be used in signing a request + """Get an OAuth signature to be used in signing a request. To satisfy `section 3.4.1.2`_ item 2, if the request argument's headers dict attribute contains a Host item, its value will @@ -81,60 +87,54 @@ def get_oauth_signature(self, method, uri, headers, body): """ sign = self.SIGNATURE_METHODS.get(self.signature_method) if not sign: - raise ValueError('Invalid signature method.') + raise ValueError("Invalid signature method.") request = OAuth1Request(method, uri, body=body, headers=headers) return sign(self, request) def get_oauth_params(self, nonce, timestamp): oauth_params = [ - ('oauth_nonce', nonce), - ('oauth_timestamp', timestamp), - ('oauth_version', '1.0'), - ('oauth_signature_method', self.signature_method), - ('oauth_consumer_key', self.client_id), + ("oauth_nonce", nonce), + ("oauth_timestamp", timestamp), + ("oauth_version", "1.0"), + ("oauth_signature_method", self.signature_method), + ("oauth_consumer_key", self.client_id), ] if self.token: - oauth_params.append(('oauth_token', self.token)) + oauth_params.append(("oauth_token", self.token)) if self.redirect_uri: - oauth_params.append(('oauth_callback', self.redirect_uri)) + oauth_params.append(("oauth_callback", self.redirect_uri)) if self.verifier: - oauth_params.append(('oauth_verifier', self.verifier)) + oauth_params.append(("oauth_verifier", self.verifier)) return oauth_params def _render(self, uri, headers, body, oauth_params): if self.signature_type == SIGNATURE_TYPE_HEADER: headers = prepare_headers(oauth_params, headers, realm=self.realm) elif self.signature_type == SIGNATURE_TYPE_BODY: - if CONTENT_TYPE_FORM_URLENCODED in headers.get('Content-Type', ''): + if CONTENT_TYPE_FORM_URLENCODED in headers.get("Content-Type", ""): decoded_body = extract_params(body) or [] body = prepare_form_encoded_body(oauth_params, decoded_body) - headers['Content-Type'] = CONTENT_TYPE_FORM_URLENCODED + headers["Content-Type"] = CONTENT_TYPE_FORM_URLENCODED elif self.signature_type == SIGNATURE_TYPE_QUERY: uri = prepare_request_uri_query(oauth_params, uri) else: - raise ValueError('Unknown signature type specified.') + raise ValueError("Unknown signature type specified.") return uri, headers, body - def sign(self, method, uri, headers, body, nonce=None, timestamp=None): + def sign(self, method, uri, headers, body): """Sign the HTTP request, add OAuth parameters and signature. :param method: HTTP method of the request. :param uri: URI of the HTTP request. :param body: Body payload of the HTTP request. :param headers: Headers of the HTTP request. - :param nonce: A string to represent nonce value. If not configured, - this method will generate one for you. - :param timestamp: Current timestamp. If not configured, this method - will generate one for you. :return: uri, headers, body """ - if nonce is None: - nonce = generate_nonce() - if timestamp is None: - timestamp = generate_timestamp() + nonce = generate_nonce() + timestamp = generate_timestamp() if body is None: - body = '' + body = b"" # transform int to str timestamp = str(timestamp) @@ -143,10 +143,17 @@ def sign(self, method, uri, headers, body, nonce=None, timestamp=None): headers = {} oauth_params = self.get_oauth_params(nonce, timestamp) + + # https://datatracker.ietf.org/doc/html/draft-eaton-oauth-bodyhash-00.html + # include oauth_body_hash + if body and headers.get("Content-Type") != CONTENT_TYPE_FORM_URLENCODED: + oauth_body_hash = base64.b64encode(hashlib.sha1(body).digest()) + oauth_params.append(("oauth_body_hash", oauth_body_hash.decode("utf-8"))) + uri, headers, body = self._render(uri, headers, body, oauth_params) sig = self.get_oauth_signature(method, uri, headers, body) - oauth_params.append(('oauth_signature', sig)) + oauth_params.append(("oauth_signature", sig)) uri, headers, body = self._render(uri, headers, body, oauth_params) return uri, headers, body @@ -157,22 +164,24 @@ def prepare(self, method, uri, headers, body): Parameters may be included from the body if the content-type is urlencoded, if no content type is set, a guess is made. """ - content_type = to_native(headers.get('Content-Type', '')) + content_type = to_native(headers.get("Content-Type", "")) if self.signature_type == SIGNATURE_TYPE_BODY: content_type = CONTENT_TYPE_FORM_URLENCODED elif not content_type and extract_params(body): content_type = CONTENT_TYPE_FORM_URLENCODED if CONTENT_TYPE_FORM_URLENCODED in content_type: - headers['Content-Type'] = CONTENT_TYPE_FORM_URLENCODED + headers["Content-Type"] = CONTENT_TYPE_FORM_URLENCODED + if isinstance(body, bytes): + body = body.decode() uri, headers, body = self.sign(method, uri, headers, body) elif self.force_include_body: # To allow custom clients to work on non form encoded bodies. uri, headers, body = self.sign(method, uri, headers, body) else: # Omit body data in the signing of non form-encoded requests - uri, headers, _ = self.sign(method, uri, headers, '') - body = '' + uri, headers, _ = self.sign(method, uri, headers, b"") + body = b"" return uri, headers, body diff --git a/authlib/oauth1/rfc5849/errors.py b/authlib/oauth1/rfc5849/errors.py index 14918331f..9826aec60 100644 --- a/authlib/oauth1/rfc5849/errors.py +++ b/authlib/oauth1/rfc5849/errors.py @@ -1,34 +1,32 @@ -""" - authlib.oauth1.rfc5849.errors - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth1.rfc5849.errors. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - RFC5849 has no definition on errors. This module is designed by - Authlib based on OAuth 1.0a `Section 10`_ with some changes. +RFC5849 has no definition on errors. This module is designed by +Authlib based on OAuth 1.0a `Section 10`_ with some changes. - .. _`Section 10`: https://oauth.net/core/1.0a/#rfc.section.10 +.. _`Section 10`: https://oauth.net/core/1.0a/#rfc.section.10 """ + from authlib.common.errors import AuthlibHTTPError from authlib.common.security import is_secure_transport class OAuth1Error(AuthlibHTTPError): def __init__(self, description=None, uri=None, status_code=None): - super(OAuth1Error, self).__init__(None, description, uri, status_code) + super().__init__(None, description, uri, status_code) def get_headers(self): """Get a list of headers.""" return [ - ('Content-Type', 'application/x-www-form-urlencoded'), - ('Cache-Control', 'no-store'), - ('Pragma', 'no-cache') + ("Content-Type", "application/x-www-form-urlencoded"), + ("Cache-Control", "no-store"), + ("Pragma", "no-cache"), ] class InsecureTransportError(OAuth1Error): - error = 'insecure_transport' - - def get_error_description(self): - return self.gettext('OAuth 2 MUST utilize https.') + error = "insecure_transport" + description = "OAuth 2 MUST utilize https." @classmethod def check(cls, uri): @@ -37,64 +35,55 @@ def check(cls, uri): class InvalidRequestError(OAuth1Error): - error = 'invalid_request' + error = "invalid_request" class UnsupportedParameterError(OAuth1Error): - error = 'unsupported_parameter' + error = "unsupported_parameter" class UnsupportedSignatureMethodError(OAuth1Error): - error = 'unsupported_signature_method' + error = "unsupported_signature_method" class MissingRequiredParameterError(OAuth1Error): - error = 'missing_required_parameter' + error = "missing_required_parameter" def __init__(self, key): - super(MissingRequiredParameterError, self).__init__() - self._key = key - - def get_error_description(self): - return self.gettext( - 'missing "%(key)s" in parameters') % dict(key=self._key) + description = f'missing "{key}" in parameters' + super().__init__(description=description) class DuplicatedOAuthProtocolParameterError(OAuth1Error): - error = 'duplicated_oauth_protocol_parameter' + error = "duplicated_oauth_protocol_parameter" class InvalidClientError(OAuth1Error): - error = 'invalid_client' + error = "invalid_client" status_code = 401 class InvalidTokenError(OAuth1Error): - error = 'invalid_token' + error = "invalid_token" + description = 'Invalid or expired "oauth_token" in parameters' status_code = 401 - def get_error_description(self): - return self.gettext('Invalid or expired "oauth_token" in parameters') - class InvalidSignatureError(OAuth1Error): - error = 'invalid_signature' + error = "invalid_signature" status_code = 401 class InvalidNonceError(OAuth1Error): - error = 'invalid_nonce' + error = "invalid_nonce" status_code = 401 class AccessDeniedError(OAuth1Error): - error = 'access_denied' - - def get_error_description(self): - return self.gettext( - 'The resource owner or authorization server denied the request') + error = "access_denied" + description = "The resource owner or authorization server denied the request" class MethodNotAllowedError(OAuth1Error): - error = 'method_not_allowed' + error = "method_not_allowed" status_code = 405 diff --git a/authlib/oauth1/rfc5849/models.py b/authlib/oauth1/rfc5849/models.py index 76befe9d3..04245d166 100644 --- a/authlib/oauth1/rfc5849/models.py +++ b/authlib/oauth1/rfc5849/models.py @@ -1,5 +1,4 @@ - -class ClientMixin(object): +class ClientMixin: def get_default_redirect_uri(self): """A method to get client default redirect_uri. For instance, the database table for client has a column called ``default_redirect_uri``:: @@ -30,7 +29,7 @@ def get_rsa_public_key(self): raise NotImplementedError() -class TokenCredentialMixin(object): +class TokenCredentialMixin: def get_oauth_token(self): """A method to get the value of ``oauth_token``. For instance, the database table has a column called ``oauth_token``:: @@ -91,19 +90,19 @@ def check_verifier(self, verifier): class TemporaryCredential(dict, TemporaryCredentialMixin): def get_client_id(self): - return self.get('client_id') + return self.get("client_id") def get_user_id(self): - return self.get('user_id') + return self.get("user_id") def get_redirect_uri(self): - return self.get('oauth_callback') + return self.get("oauth_callback") def check_verifier(self, verifier): - return self.get('oauth_verifier') == verifier + return self.get("oauth_verifier") == verifier def get_oauth_token(self): - return self.get('oauth_token') + return self.get("oauth_token") def get_oauth_token_secret(self): - return self.get('oauth_token_secret') + return self.get("oauth_token_secret") diff --git a/authlib/oauth1/rfc5849/parameters.py b/authlib/oauth1/rfc5849/parameters.py index 4746aeaad..545742440 100644 --- a/authlib/oauth1/rfc5849/parameters.py +++ b/authlib/oauth1/rfc5849/parameters.py @@ -1,14 +1,15 @@ -# -*- coding: utf-8 -*- +"""authlib.spec.rfc5849.parameters. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +This module contains methods related to `section 3.5`_ of the OAuth 1.0a spec. + +.. _`section 3.5`: https://tools.ietf.org/html/rfc5849#section-3.5 """ - authlib.spec.rfc5849.parameters - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - This module contains methods related to `section 3.5`_ of the OAuth 1.0a spec. +from authlib.common.urls import extract_params +from authlib.common.urls import url_encode +from authlib.common.urls import urlparse - .. _`section 3.5`: https://tools.ietf.org/html/rfc5849#section-3.5 -""" -from authlib.common.urls import urlparse, url_encode, extract_params from .util import escape @@ -37,10 +38,13 @@ def prepare_headers(oauth_params, headers=None, realm=None): headers = headers or {} # step 1, 2, 3 in Section 3.5.1 - header_parameters = ', '.join([ - '{0}="{1}"'.format(escape(k), escape(v)) for k, v in oauth_params - if k.startswith('oauth_') - ]) + header_parameters = ", ".join( + [ + f'{escape(k)}="{escape(v)}"' + for k, v in oauth_params + if k.startswith("oauth_") + ] + ) # 4. The OPTIONAL "realm" parameter MAY be added and interpreted per # `RFC2617 section 1.2`_. @@ -48,10 +52,10 @@ def prepare_headers(oauth_params, headers=None, realm=None): # .. _`RFC2617 section 1.2`: https://tools.ietf.org/html/rfc2617#section-1.2 if realm: # NOTE: realm should *not* be escaped - header_parameters = 'realm="{}", '.format(realm) + header_parameters + header_parameters = f'realm="{realm}", ' + header_parameters # the auth-scheme name set to "OAuth" (case insensitive). - headers['Authorization'] = 'OAuth {}'.format(header_parameters) + headers["Authorization"] = f"OAuth {header_parameters}" return headers @@ -72,7 +76,7 @@ def _append_params(oauth_params, params): # parameters, in which case, the protocol parameters SHOULD be appended # following the request-specific parameters, properly separated by an "&" # character (ASCII code 38) - merged.sort(key=lambda i: i[0].startswith('oauth_')) + merged.sort(key=lambda i: i[0].startswith("oauth_")) return merged @@ -98,6 +102,5 @@ def prepare_request_uri_query(oauth_params, uri): """ # append OAuth params to the existing set of query components sch, net, path, par, query, fra = urlparse.urlparse(uri) - query = url_encode( - _append_params(oauth_params, extract_params(query) or [])) + query = url_encode(_append_params(oauth_params, extract_params(query) or [])) return urlparse.urlunparse((sch, net, path, par, query, fra)) diff --git a/authlib/oauth1/rfc5849/resource_protector.py b/authlib/oauth1/rfc5849/resource_protector.py index 2b5d7819c..364b6b5a6 100644 --- a/authlib/oauth1/rfc5849/resource_protector.py +++ b/authlib/oauth1/rfc5849/resource_protector.py @@ -1,10 +1,8 @@ from .base_server import BaseServer +from .errors import InvalidClientError +from .errors import InvalidTokenError +from .errors import MissingRequiredParameterError from .wrapper import OAuth1Request -from .errors import ( - MissingRequiredParameterError, - InvalidClientError, - InvalidTokenError, -) class ResourceProtector(BaseServer): @@ -12,7 +10,7 @@ def validate_request(self, method, uri, body, headers): request = OAuth1Request(method, uri, body, headers) if not request.client_id: - raise MissingRequiredParameterError('oauth_consumer_key') + raise MissingRequiredParameterError("oauth_consumer_key") client = self.get_client_by_id(request.client_id) if not client: @@ -20,7 +18,7 @@ def validate_request(self, method, uri, body, headers): request.client = client if not request.token: - raise MissingRequiredParameterError('oauth_token') + raise MissingRequiredParameterError("oauth_token") token = self.get_token_credential(request) if not token: diff --git a/authlib/oauth1/rfc5849/rsa.py b/authlib/oauth1/rfc5849/rsa.py index 3785b0f79..fd68fcd2f 100644 --- a/authlib/oauth1/rfc5849/rsa.py +++ b/authlib/oauth1/rfc5849/rsa.py @@ -1,27 +1,22 @@ -from cryptography.hazmat.primitives import hashes +from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives.serialization import ( - load_pem_private_key, load_pem_public_key -) +from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import padding -from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.primitives.serialization import load_pem_private_key +from cryptography.hazmat.primitives.serialization import load_pem_public_key + from authlib.common.encoding import to_bytes def sign_sha1(msg, rsa_private_key): key = load_pem_private_key( - to_bytes(rsa_private_key), - password=None, - backend=default_backend() + to_bytes(rsa_private_key), password=None, backend=default_backend() ) return key.sign(msg, padding.PKCS1v15(), hashes.SHA1()) def verify_sha1(sig, msg, rsa_public_key): - key = load_pem_public_key( - to_bytes(rsa_public_key), - backend=default_backend() - ) + key = load_pem_public_key(to_bytes(rsa_public_key), backend=default_backend()) try: key.verify(sig, msg, padding.PKCS1v15(), hashes.SHA1()) return True diff --git a/authlib/oauth1/rfc5849/signature.py b/authlib/oauth1/rfc5849/signature.py index 6ba67e2d5..d12e44a53 100644 --- a/authlib/oauth1/rfc5849/signature.py +++ b/authlib/oauth1/rfc5849/signature.py @@ -1,26 +1,29 @@ -# -*- coding: utf-8 -*- -""" - authlib.oauth1.rfc5849.signature - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth1.rfc5849.signature. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - This module represents a direct implementation of `section 3.4`_ of the spec. +This module represents a direct implementation of `section 3.4`_ of the spec. - .. _`section 3.4`: https://tools.ietf.org/html/rfc5849#section-3.4 +.. _`section 3.4`: https://tools.ietf.org/html/rfc5849#section-3.4 """ + import binascii import hashlib import hmac + +from authlib.common.encoding import to_bytes +from authlib.common.encoding import to_unicode from authlib.common.urls import urlparse -from authlib.common.encoding import to_unicode, to_bytes -from .util import escape, unescape + +from .util import escape +from .util import unescape SIGNATURE_HMAC_SHA1 = "HMAC-SHA1" SIGNATURE_RSA_SHA1 = "RSA-SHA1" SIGNATURE_PLAINTEXT = "PLAINTEXT" -SIGNATURE_TYPE_HEADER = 'HEADER' -SIGNATURE_TYPE_QUERY = 'QUERY' -SIGNATURE_TYPE_BODY = 'BODY' +SIGNATURE_TYPE_HEADER = "HEADER" +SIGNATURE_TYPE_QUERY = "QUERY" +SIGNATURE_TYPE_BODY = "BODY" def construct_base_string(method, uri, params, host=None): @@ -52,7 +55,6 @@ def construct_base_string(method, uri, params, host=None): .. _`Section 3.4.1`: https://tools.ietf.org/html/rfc5849#section-3.4.1 """ - # Create base string URI per Section 3.4.1.2 base_string_uri = normalize_base_string_uri(uri, host) @@ -60,11 +62,11 @@ def construct_base_string(method, uri, params, host=None): unescaped_params = [] for k, v in params: # The "oauth_signature" parameter MUST be excluded from the signature - if k in ('oauth_signature', 'realm'): + if k in ("oauth_signature", "realm"): continue # ensure oauth params are unescaped - if k.startswith('oauth_'): + if k.startswith("oauth_"): v = unescape(v) unescaped_params.append((k, v)) @@ -72,11 +74,13 @@ def construct_base_string(method, uri, params, host=None): normalized_params = normalize_parameters(unescaped_params) # construct base string - return '&'.join([ - escape(method.upper()), - escape(base_string_uri), - escape(normalized_params), - ]) + return "&".join( + [ + escape(method.upper()), + escape(base_string_uri), + escape(normalized_params), + ] + ) def normalize_base_string_uri(uri, host=None): @@ -110,7 +114,7 @@ def normalize_base_string_uri(uri, host=None): # .. _`RFC3986`: https://tools.ietf.org/html/rfc3986 if not scheme or not netloc: - raise ValueError('uri must include a scheme and netloc') + raise ValueError("uri must include a scheme and netloc") # Per `RFC 2616 section 5.1.2`_: # @@ -119,7 +123,7 @@ def normalize_base_string_uri(uri, host=None): # # .. _`RFC 2616 section 5.1.2`: https://tools.ietf.org/html/rfc2616#section-5.1.2 if not path: - path = '/' + path = "/" # 1. The scheme and host MUST be in lowercase. scheme = scheme.lower() @@ -139,15 +143,15 @@ def normalize_base_string_uri(uri, host=None): # .. _`RFC2616`: https://tools.ietf.org/html/rfc2616 # .. _`RFC2818`: https://tools.ietf.org/html/rfc2818 default_ports = ( - ('http', '80'), - ('https', '443'), + ("http", "80"), + ("https", "443"), ) - if ':' in netloc: - host, port = netloc.split(':', 1) + if ":" in netloc: + host, port = netloc.split(":", 1) if (scheme, port) in default_ports: netloc = host - return urlparse.urlunparse((scheme, netloc, path, params, '', '')) + return urlparse.urlunparse((scheme, netloc, path, params, "", "")) def normalize_parameters(params): @@ -219,7 +223,6 @@ def normalize_parameters(params): .. _`Section 3.4.1.3.2`: https://tools.ietf.org/html/rfc5849#section-3.4.1.3.2 """ - # 1. First, the name and value of each parameter are encoded # (`Section 3.6`_). # @@ -234,19 +237,18 @@ def normalize_parameters(params): # 3. The name of each parameter is concatenated to its corresponding # value using an "=" character (ASCII code 61) as a separator, even # if the value is empty. - parameter_parts = ['{0}={1}'.format(k, v) for k, v in key_values] + parameter_parts = [f"{k}={v}" for k, v in key_values] # 4. The sorted name/value pairs are concatenated together into a # single string by using an "&" character (ASCII code 38) as # separator. - return '&'.join(parameter_parts) + return "&".join(parameter_parts) def generate_signature_base_string(request): """Generate signature base string from request.""" - host = request.headers.get('Host', None) - return construct_base_string( - request.method, request.uri, request.params, host) + host = request.headers.get("Host", None) + return construct_base_string(request.method, request.uri, request.params, host) def hmac_sha1_signature(base_string, client_secret, token_secret): @@ -255,12 +257,11 @@ def hmac_sha1_signature(base_string, client_secret, token_secret): The "HMAC-SHA1" signature method uses the HMAC-SHA1 signature algorithm as defined in `RFC2104`_:: - digest = HMAC-SHA1 (key, text) + digest = HMAC - SHA1(key, text) .. _`RFC2104`: https://tools.ietf.org/html/rfc2104 .. _`Section 3.4.2`: https://tools.ietf.org/html/rfc5849#section-3.4.2 """ - # The HMAC-SHA1 function variables are used in following way: # text is set to the value of the signature base string from @@ -273,16 +274,16 @@ def hmac_sha1_signature(base_string, client_secret, token_secret): # 1. The client shared-secret, after being encoded (`Section 3.6`_). # # .. _`Section 3.6`: https://tools.ietf.org/html/rfc5849#section-3.6 - key = escape(client_secret or '') + key = escape(client_secret or "") # 2. An "&" character (ASCII code 38), which MUST be included # even when either secret is empty. - key += '&' + key += "&" # 3. The token shared-secret, after being encoded (`Section 3.6`_). # # .. _`Section 3.6`: https://tools.ietf.org/html/rfc5849#section-3.6 - key += escape(token_secret or '') + key += escape(token_secret or "") signature = hmac.new(to_bytes(key), to_bytes(text), hashlib.sha1) @@ -309,6 +310,7 @@ def rsa_sha1_signature(base_string, rsa_private_key): .. _`RFC3447, Section 8.2`: https://tools.ietf.org/html/rfc3447#section-8.2 """ from .rsa import sign_sha1 + base_string = to_bytes(base_string) s = sign_sha1(to_bytes(base_string), rsa_private_key) sig = binascii.b2a_base64(s)[:-1] @@ -326,23 +328,22 @@ def plaintext_signature(client_secret, token_secret): .. _`Section 3.4.4`: https://tools.ietf.org/html/rfc5849#section-3.4.4 """ - # The "oauth_signature" protocol parameter is set to the concatenated # value of: # 1. The client shared-secret, after being encoded (`Section 3.6`_). # # .. _`Section 3.6`: https://tools.ietf.org/html/rfc5849#section-3.6 - signature = escape(client_secret or '') + signature = escape(client_secret or "") # 2. An "&" character (ASCII code 38), which MUST be included even # when either secret is empty. - signature += '&' + signature += "&" # 3. The token shared-secret, after being encoded (`Section 3.6`_). # # .. _`Section 3.6`: https://tools.ietf.org/html/rfc5849#section-3.6 - signature += escape(token_secret or '') + signature += escape(token_secret or "") return signature @@ -350,8 +351,7 @@ def plaintext_signature(client_secret, token_secret): def sign_hmac_sha1(client, request): """Sign a HMAC-SHA1 signature.""" base_string = generate_signature_base_string(request) - return hmac_sha1_signature( - base_string, client.client_secret, client.token_secret) + return hmac_sha1_signature(base_string, client.client_secret, client.token_secret) def sign_rsa_sha1(client, request): @@ -368,14 +368,14 @@ def sign_plaintext(client, request): def verify_hmac_sha1(request): """Verify a HMAC-SHA1 signature.""" base_string = generate_signature_base_string(request) - sig = hmac_sha1_signature( - base_string, request.client_secret, request.token_secret) + sig = hmac_sha1_signature(base_string, request.client_secret, request.token_secret) return hmac.compare_digest(sig, request.signature) def verify_rsa_sha1(request): """Verify a RSASSA-PKCS #1 v1.5 base64 encoded signature.""" from .rsa import verify_sha1 + base_string = generate_signature_base_string(request) sig = binascii.a2b_base64(to_bytes(request.signature)) return verify_sha1(sig, to_bytes(base_string), request.rsa_public_key) diff --git a/authlib/oauth1/rfc5849/util.py b/authlib/oauth1/rfc5849/util.py index 9383e22ed..fb1e0ca34 100644 --- a/authlib/oauth1/rfc5849/util.py +++ b/authlib/oauth1/rfc5849/util.py @@ -1,8 +1,9 @@ -from authlib.common.urls import quote, unquote +from authlib.common.urls import quote +from authlib.common.urls import unquote def escape(s): - return quote(s, safe=b'~') + return quote(s, safe=b"~") def unescape(s): diff --git a/authlib/oauth1/rfc5849/wrapper.py b/authlib/oauth1/rfc5849/wrapper.py index 9f889a30c..cd3c43e71 100644 --- a/authlib/oauth1/rfc5849/wrapper.py +++ b/authlib/oauth1/rfc5849/wrapper.py @@ -1,20 +1,19 @@ -from authlib.common.urls import ( - urlparse, extract_params, url_decode, - parse_http_list, parse_keqv_list, -) -from .signature import ( - SIGNATURE_TYPE_QUERY, - SIGNATURE_TYPE_BODY, - SIGNATURE_TYPE_HEADER -) -from .errors import ( - InsecureTransportError, - DuplicatedOAuthProtocolParameterError -) +from urllib.request import parse_http_list +from urllib.request import parse_keqv_list + +from authlib.common.urls import extract_params +from authlib.common.urls import url_decode +from authlib.common.urls import urlparse + +from .errors import DuplicatedOAuthProtocolParameterError +from .errors import InsecureTransportError +from .signature import SIGNATURE_TYPE_BODY +from .signature import SIGNATURE_TYPE_HEADER +from .signature import SIGNATURE_TYPE_QUERY from .util import unescape -class OAuth1Request(object): +class OAuth1Request: def __init__(self, method, uri, body=None, headers=None): InsecureTransportError.check(uri) self.method = method @@ -33,7 +32,8 @@ def __init__(self, method, uri, body=None, headers=None): self.auth_params, self.realm = _parse_authorization_header(headers) self.signature_type, self.oauth_params = _parse_oauth_params( - self.query_params, self.body_params, self.auth_params) + self.query_params, self.body_params, self.auth_params + ) params = [] params.extend(self.query_params) @@ -43,7 +43,7 @@ def __init__(self, method, uri, body=None, headers=None): @property def client_id(self): - return self.oauth_params.get('oauth_consumer_key') + return self.oauth_params.get("oauth_consumer_key") @property def client_secret(self): @@ -57,23 +57,23 @@ def rsa_public_key(self): @property def timestamp(self): - return self.oauth_params.get('oauth_timestamp') + return self.oauth_params.get("oauth_timestamp") @property def redirect_uri(self): - return self.oauth_params.get('oauth_callback') + return self.oauth_params.get("oauth_callback") @property def signature(self): - return self.oauth_params.get('oauth_signature') + return self.oauth_params.get("oauth_signature") @property def signature_method(self): - return self.oauth_params.get('oauth_signature_method') + return self.oauth_params.get("oauth_signature_method") @property def token(self): - return self.oauth_params.get('oauth_token') + return self.oauth_params.get("oauth_token") @property def token_secret(self): @@ -83,41 +83,41 @@ def token_secret(self): def _filter_oauth(params): for k, v in params: - if k.startswith('oauth_'): + if k.startswith("oauth_"): yield (k, v) def _parse_authorization_header(headers): - """Parse an OAuth authorization header into a list of 2-tuples""" - authorization_header = headers.get('Authorization') + """Parse an OAuth authorization header into a list of 2-tuples.""" + authorization_header = headers.get("Authorization") if not authorization_header: return [], None - auth_scheme = 'oauth ' + auth_scheme = "oauth " if authorization_header.lower().startswith(auth_scheme): - items = parse_http_list(authorization_header[len(auth_scheme):]) + items = parse_http_list(authorization_header[len(auth_scheme) :]) try: items = parse_keqv_list(items).items() auth_params = [(unescape(k), unescape(v)) for k, v in items] - realm = dict(auth_params).get('realm') + realm = dict(auth_params).get("realm") return auth_params, realm except (IndexError, ValueError): pass - raise ValueError('Malformed authorization header') + raise ValueError("Malformed authorization header") def _parse_oauth_params(query_params, body_params, auth_params): oauth_params_set = [ (SIGNATURE_TYPE_QUERY, list(_filter_oauth(query_params))), (SIGNATURE_TYPE_BODY, list(_filter_oauth(body_params))), - (SIGNATURE_TYPE_HEADER, list(_filter_oauth(auth_params))) + (SIGNATURE_TYPE_HEADER, list(_filter_oauth(auth_params))), ] oauth_params_set = [params for params in oauth_params_set if params[1]] if len(oauth_params_set) > 1: found_types = [p[0] for p in oauth_params_set] raise DuplicatedOAuthProtocolParameterError( '"oauth_" params must come from only 1 signature type ' - 'but were found in {}'.format(','.join(found_types)) + "but were found in {}".format(",".join(found_types)) ) if oauth_params_set: diff --git a/authlib/oauth2/__init__.py b/authlib/oauth2/__init__.py index 23dea91b9..76bb873c2 100644 --- a/authlib/oauth2/__init__.py +++ b/authlib/oauth2/__init__.py @@ -1,16 +1,21 @@ +from .auth import ClientAuth +from .auth import TokenAuth from .base import OAuth2Error -from .auth import ClientAuth, TokenAuth from .client import OAuth2Client -from .rfc6749 import ( - OAuth2Request, - HttpRequest, - AuthorizationServer, - ClientAuthentication, - ResourceProtector, -) +from .rfc6749 import AuthorizationServer +from .rfc6749 import ClientAuthentication +from .rfc6749 import JsonRequest +from .rfc6749 import OAuth2Request +from .rfc6749 import ResourceProtector __all__ = [ - 'OAuth2Error', 'ClientAuth', 'TokenAuth', 'OAuth2Client', - 'OAuth2Request', 'HttpRequest', 'AuthorizationServer', - 'ClientAuthentication', 'ResourceProtector', + "OAuth2Error", + "ClientAuth", + "TokenAuth", + "OAuth2Client", + "OAuth2Request", + "JsonRequest", + "AuthorizationServer", + "ClientAuthentication", + "ResourceProtector", ] diff --git a/authlib/oauth2/auth.py b/authlib/oauth2/auth.py index 1d7a655a8..dffccb7f2 100644 --- a/authlib/oauth2/auth.py +++ b/authlib/oauth2/auth.py @@ -1,38 +1,45 @@ import base64 -from authlib.common.urls import add_params_to_qs, add_params_to_uri -from authlib.common.encoding import to_bytes, to_native + +from authlib.common.encoding import to_bytes +from authlib.common.encoding import to_native +from authlib.common.urls import add_params_to_qs +from authlib.common.urls import add_params_to_uri + from .rfc6749 import OAuth2Token from .rfc6750 import add_bearer_token def encode_client_secret_basic(client, method, uri, headers, body): - text = '{}:{}'.format(client.client_id, client.client_secret) - auth = to_native(base64.urlsafe_b64encode(to_bytes(text, 'latin1'))) - headers['Authorization'] = 'Basic {}'.format(auth) + text = f"{client.client_id}:{client.client_secret}" + auth = to_native(base64.b64encode(to_bytes(text, "latin1"))) + headers["Authorization"] = f"Basic {auth}" return uri, headers, body def encode_client_secret_post(client, method, uri, headers, body): - body = add_params_to_qs(body or '', [ - ('client_id', client.client_id), - ('client_secret', client.client_secret or '') - ]) - if 'Content-Length' in headers: - headers['Content-Length'] = str(len(body)) + body = add_params_to_qs( + body or "", + [ + ("client_id", client.client_id), + ("client_secret", client.client_secret or ""), + ], + ) + if "Content-Length" in headers: + headers["Content-Length"] = str(len(body)) return uri, headers, body def encode_none(client, method, uri, headers, body): - if method == 'GET': - uri = add_params_to_uri(uri, [('client_id', client.client_id)]) + if method == "GET": + uri = add_params_to_uri(uri, [("client_id", client.client_id)]) return uri, headers, body - body = add_params_to_qs(body, [('client_id', client.client_id)]) - if 'Content-Length' in headers: - headers['Content-Length'] = str(len(body)) + body = add_params_to_qs(body, [("client_id", client.client_id)]) + if "Content-Length" in headers: + headers["Content-Length"] = str(len(body)) return uri, headers, body -class ClientAuth(object): +class ClientAuth: """Attaches OAuth Client Information to HTTP requests. :param client_id: Client ID, which you get from client registration. @@ -44,15 +51,16 @@ class ClientAuth(object): * client_secret_post * none """ + DEFAULT_AUTH_METHODS = { - 'client_secret_basic': encode_client_secret_basic, - 'client_secret_post': encode_client_secret_post, - 'none': encode_none, + "client_secret_basic": encode_client_secret_basic, + "client_secret_post": encode_client_secret_post, + "none": encode_none, } def __init__(self, client_id, client_secret, auth_method=None): if auth_method is None: - auth_method = 'client_secret_basic' + auth_method = "client_secret_basic" self.client_id = client_id self.client_secret = client_secret @@ -66,7 +74,7 @@ def prepare(self, method, uri, headers, body): return self.auth_method(self, method, uri, headers, body) -class TokenAuth(object): +class TokenAuth: """Attach token information to HTTP requests. :param token: A dict or OAuth2Token instance of an OAuth 2.0 token @@ -77,12 +85,11 @@ class TokenAuth(object): * body * uri """ - DEFAULT_TOKEN_TYPE = 'bearer' - SIGN_METHODS = { - 'bearer': add_bearer_token - } - def __init__(self, token, token_placement='header', client=None): + DEFAULT_TOKEN_TYPE = "bearer" + SIGN_METHODS = {"bearer": add_bearer_token} + + def __init__(self, token, token_placement="header", client=None): self.token = OAuth2Token.from_dict(token) self.token_placement = token_placement self.client = client @@ -92,14 +99,17 @@ def set_token(self, token): self.token = OAuth2Token.from_dict(token) def prepare(self, uri, headers, body): - token_type = self.token.get('token_type', self.DEFAULT_TOKEN_TYPE) + token_type = self.token.get("token_type", self.DEFAULT_TOKEN_TYPE) sign = self.SIGN_METHODS[token_type.lower()] uri, headers, body = sign( - self.token['access_token'], - uri, headers, body, - self.token_placement) + self.token["access_token"], uri, headers, body, self.token_placement + ) for hook in self.hooks: uri, headers, body = hook(uri, headers, body) return uri, headers, body + + def __del__(self): + del self.client + del self.hooks diff --git a/authlib/oauth2/base.py b/authlib/oauth2/base.py index 5fea8e084..407c0935d 100644 --- a/authlib/oauth2/base.py +++ b/authlib/oauth2/base.py @@ -2,26 +2,61 @@ from authlib.common.urls import add_params_to_uri +def invalid_error_characters(text: str) -> list[str]: + """Check whether the string only contains characters from the restricted ASCII set defined in RFC6749 for errors. + + https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1 + """ + valid_ranges = [ + (0x20, 0x21), + (0x23, 0x5B), + (0x5D, 0x7E), + ] + + return [ + char + for char in set(text) + if not any(start <= ord(char) <= end for start, end in valid_ranges) + ] + + class OAuth2Error(AuthlibHTTPError): - def __init__(self, description=None, uri=None, - status_code=None, state=None, - redirect_uri=None, redirect_fragment=False, error=None): - super(OAuth2Error, self).__init__(error, description, uri, status_code) + def __init__( + self, + description=None, + uri=None, + status_code=None, + state=None, + redirect_uri=None, + redirect_fragment=False, + error=None, + ): + # Human-readable ASCII [USASCII] text providing + # additional information, used to assist the client developer in + # understanding the error that occurred. + # Values for the "error_description" parameter MUST NOT include + # characters outside the set %x20-21 / %x23-5B / %x5D-7E. + if description: + if chars := invalid_error_characters(description): + raise ValueError( + f"Error description contains forbidden characters: {', '.join(chars)}." + ) + + super().__init__(error, description, uri, status_code) self.state = state self.redirect_uri = redirect_uri self.redirect_fragment = redirect_fragment def get_body(self): """Get a list of body.""" - error = super(OAuth2Error, self).get_body() + error = super().get_body() if self.state: - error.append(('state', self.state)) + error.append(("state", self.state)) return error - def __call__(self, translations=None, error_uris=None): + def __call__(self, uri=None): if self.redirect_uri: params = self.get_body() - loc = add_params_to_uri( - self.redirect_uri, params, self.redirect_fragment) - return 302, '', [('Location', loc)] - return super(OAuth2Error, self).__call__(translations, error_uris) + loc = add_params_to_uri(self.redirect_uri, params, self.redirect_fragment) + return 302, "", [("Location", loc)] + return super().__call__(uri=uri) diff --git a/authlib/oauth2/claims.py b/authlib/oauth2/claims.py new file mode 100644 index 000000000..3b528b80e --- /dev/null +++ b/authlib/oauth2/claims.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any +from typing import TypedDict + +from joserfc.errors import InvalidClaimError +from joserfc.jwt import BaseClaimsRegistry +from joserfc.jwt import Claims +from joserfc.jwt import JWTClaimsRegistry +from joserfc.registry import Header + + +class ClaimsOption(TypedDict, total=False): + essential: bool + allow_blank: bool | None + value: str | int | bool + values: list[str | int | bool] | list[str] | list[int] | list[bool] + validate: Callable[[BaseClaims, Any], bool] + + +class BaseClaims(dict): + registry_cls = BaseClaimsRegistry + REGISTERED_CLAIMS = [] + + def __init__( + self, + claims: Claims, + header: Header, + options: dict[str, ClaimsOption] | None = None, + params: dict[str, Any] = None, + ): + super().__init__(claims) + self._validate_hooks = {} + self.header = header + if options: + self._extract_validate_hooks(options) + self.options = options or {} + self.params = params or {} + + def _extract_validate_hooks(self, options: dict[str, ClaimsOption]): + for key in options: + validate = options[key].pop("validate", None) + if validate: + self._validate_hooks[key] = validate + + def _run_validate_hooks(self): + for key in self._validate_hooks: + validate = self._validate_hooks[key] + if validate and key in self and not validate(self, self[key]): + raise InvalidClaimError(key) + + def get_registered_claims(self): + rv = {} + for k in self.REGISTERED_CLAIMS: + if k in self: + rv[k] = self[k] + return rv + + def validate(self, now=None, leeway=0): + validator = self.registry_cls(**self.options) + validator.validate(self) + self._run_validate_hooks() + + +class JWTClaims(BaseClaims): + registry_cls = JWTClaimsRegistry + REGISTERED_CLAIMS = ["iss", "sub", "aud", "exp", "nbf", "iat", "jti"] + + def validate(self, now=None, leeway=0): + if self.options: + validator = self.registry_cls(now, leeway, **self.options) + else: + validator = self.registry_cls(now, leeway) + validator.validate(self) + self._run_validate_hooks() diff --git a/authlib/oauth2/client.py b/authlib/oauth2/client.py index 2720d4643..340c11bbe 100644 --- a/authlib/oauth2/client.py +++ b/authlib/oauth2/client.py @@ -1,23 +1,23 @@ from authlib.common.security import generate_token from authlib.common.urls import url_decode -from authlib.common.encoding import text_types -from .rfc6749.parameters import ( - prepare_grant_uri, - prepare_token_request, - parse_authorization_code_response, - parse_implicit_response, -) + +from .auth import ClientAuth +from .auth import TokenAuth +from .base import OAuth2Error +from .rfc6749.parameters import parse_authorization_code_response +from .rfc6749.parameters import parse_implicit_response +from .rfc6749.parameters import prepare_grant_uri +from .rfc6749.parameters import prepare_token_request from .rfc7009 import prepare_revoke_token_request from .rfc7636 import create_s256_code_challenge -from .auth import TokenAuth, ClientAuth DEFAULT_HEADERS = { - 'Accept': 'application/json', - 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8' + "Accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded;charset=UTF-8", } -class OAuth2Client(object): +class OAuth2Client: """Construct a new OAuth 2 protocol client. :param session: Requests session object to communicate with @@ -29,6 +29,7 @@ class OAuth2Client(object): :param revocation_endpoint_auth_method: client authentication method for revocation endpoint. :param scope: Scope that you needed to access user resources. + :param state: Shared secret to prevent CSRF attack. :param redirect_uri: Redirect URI you registered as callback. :param code_challenge_method: PKCE method name, only S256 is supported. :param token: A dict of token attributes such as ``access_token``, @@ -37,38 +38,53 @@ class OAuth2Client(object): values: "header", "body", "uri". :param update_token: A function for you to update token. It accept a :class:`OAuth2Token` as parameter. + :param leeway: Time window in seconds before the actual expiration of the + authentication token, that the token is considered expired and will + be refreshed. """ + client_auth_class = ClientAuth token_auth_class = TokenAuth + oauth_error_class = OAuth2Error - EXTRA_AUTHORIZE_PARAMS = ( - 'response_mode', 'nonce', 'prompt', 'login_hint' - ) + EXTRA_AUTHORIZE_PARAMS = ("response_mode", "nonce", "prompt", "login_hint") SESSION_REQUEST_PARAMS = [] - def __init__(self, session, client_id=None, client_secret=None, - token_endpoint_auth_method=None, - revocation_endpoint_auth_method=None, - scope=None, redirect_uri=None, code_challenge_method=None, - token=None, token_placement='header', update_token=None, **metadata): - + def __init__( + self, + session, + client_id=None, + client_secret=None, + token_endpoint_auth_method=None, + revocation_endpoint_auth_method=None, + scope=None, + state=None, + redirect_uri=None, + code_challenge_method=None, + token=None, + token_placement="header", + update_token=None, + leeway=60, + **metadata, + ): self.session = session self.client_id = client_id self.client_secret = client_secret + self.state = state if token_endpoint_auth_method is None: if client_secret: - token_endpoint_auth_method = 'client_secret_basic' + token_endpoint_auth_method = "client_secret_basic" else: - token_endpoint_auth_method = 'none' + token_endpoint_auth_method = "none" self.token_endpoint_auth_method = token_endpoint_auth_method if revocation_endpoint_auth_method is None: if client_secret: - revocation_endpoint_auth_method = 'client_secret_basic' + revocation_endpoint_auth_method = "client_secret_basic" else: - revocation_endpoint_auth_method = 'none' + revocation_endpoint_auth_method = "none" self.revocation_endpoint_auth_method = revocation_endpoint_auth_method @@ -79,21 +95,25 @@ def __init__(self, session, client_id=None, client_secret=None, self.token_auth = self.token_auth_class(token, token_placement, self) self.update_token = update_token - token_updater = metadata.pop('token_updater', None) + token_updater = metadata.pop("token_updater", None) if token_updater: - raise ValueError('update token has been redesigned, checkout the documentation') + raise ValueError( + "update token has been redesigned, checkout the documentation" + ) self.metadata = metadata self.compliance_hook = { - 'access_token_response': set(), - 'refresh_token_request': set(), - 'refresh_token_response': set(), - 'revoke_token_request': set(), - 'introspect_token_request': set(), + "access_token_response": set(), + "refresh_token_request": set(), + "refresh_token_response": set(), + "revoke_token_request": set(), + "introspect_token_request": set(), } self._auth_methods = {} + self.leeway = leeway + def register_client_auth_method(self, auth): """Extend client authenticate for token endpoint. @@ -105,7 +125,7 @@ def register_client_auth_method(self, auth): self._auth_methods[auth.name] = auth def client_auth(self, auth_method): - if isinstance(auth_method, text_types) and auth_method in self._auth_methods: + if isinstance(auth_method, str) and auth_method in self._auth_methods: auth_method = self._auth_methods[auth_method] return self.client_auth_class( client_id=self.client_id, @@ -134,28 +154,45 @@ def create_authorization_url(self, url, state=None, code_verifier=None, **kwargs if state is None: state = generate_token() - response_type = self.metadata.get('response_type', 'code') - response_type = kwargs.pop('response_type', response_type) - if 'redirect_uri' not in kwargs: - kwargs['redirect_uri'] = self.redirect_uri - if 'scope' not in kwargs: - kwargs['scope'] = self.scope - - if code_verifier and response_type == 'code' and self.code_challenge_method == 'S256': - kwargs['code_challenge'] = create_s256_code_challenge(code_verifier) - kwargs['code_challenge_method'] = self.code_challenge_method + response_type = self.metadata.get("response_type", "code") + response_type = kwargs.pop("response_type", response_type) + if "redirect_uri" not in kwargs: + kwargs["redirect_uri"] = self.redirect_uri + if "scope" not in kwargs: + kwargs["scope"] = self.scope + + if ( + code_verifier + and response_type == "code" + and self.code_challenge_method == "S256" + ): + kwargs["code_challenge"] = create_s256_code_challenge(code_verifier) + kwargs["code_challenge_method"] = self.code_challenge_method for k in self.EXTRA_AUTHORIZE_PARAMS: if k not in kwargs and k in self.metadata: kwargs[k] = self.metadata[k] uri = prepare_grant_uri( - url, client_id=self.client_id, response_type=response_type, - state=state, **kwargs) + url, + client_id=self.client_id, + response_type=response_type, + state=state, + **kwargs, + ) return uri, state - def fetch_token(self, url=None, body='', method='POST', headers=None, - auth=None, grant_type=None, **kwargs): + def fetch_token( + self, + url=None, + body="", + method="POST", + headers=None, + auth=None, + grant_type=None, + state=None, + **kwargs, + ): """Generic method for fetching an access token from the token endpoint. :param url: Access Token endpoint URL, if not configured, @@ -168,26 +205,32 @@ def fetch_token(self, url=None, body='', method='POST', headers=None, be added as needed. :param headers: Dict to default request headers with. :param auth: An auth tuple or method as accepted by requests. - :param grant_type: Use specified grant_type to fetch token + :param grant_type: Use specified grant_type to fetch token. + :param state: Optional "state" value to fetch token. :return: A :class:`OAuth2Token` object (a dict too). """ + state = state or self.state # implicit grant_type - authorization_response = kwargs.pop('authorization_response', None) - if authorization_response and '#' in authorization_response: - return self.token_from_fragment(authorization_response, kwargs.get('state')) + authorization_response = kwargs.pop("authorization_response", None) + if authorization_response and "#" in authorization_response: + return self.token_from_fragment(authorization_response, state) session_kwargs = self._extract_session_request_params(kwargs) - if authorization_response and 'code=' in authorization_response: - grant_type = 'authorization_code' + if authorization_response and "code=" in authorization_response: + grant_type = "authorization_code" params = parse_authorization_code_response( authorization_response, - state=kwargs.get('state'), + state=state, ) - kwargs['code'] = params['code'] + kwargs["code"] = params["code"] + + if grant_type is None: + grant_type = self.metadata.get("grant_type") if grant_type is None: - grant_type = self.metadata.get('grant_type') + grant_type = _guess_grant_type(kwargs) + self.metadata["grant_type"] = grant_type body = self._prepare_token_endpoint_body(body, grant_type, **kwargs) @@ -198,39 +241,24 @@ def fetch_token(self, url=None, body='', method='POST', headers=None, headers = DEFAULT_HEADERS if url is None: - url = self.metadata.get('token_endpoint') + url = self.metadata.get("token_endpoint") return self._fetch_token( - url, body=body, auth=auth, method=method, - headers=headers, **session_kwargs + url, body=body, auth=auth, method=method, headers=headers, **session_kwargs ) - def _fetch_token(self, url, body='', headers=None, auth=None, - method='POST', **kwargs): - if method == 'GET': - if '?' in url: - url = '&'.join([url, body]) - else: - url = '?'.join([url, body]) - body = '' - - if headers is None: - headers = DEFAULT_HEADERS - - resp = self.session.request( - method, url, data=body, headers=headers, auth=auth, **kwargs) - - for hook in self.compliance_hook['access_token_response']: - resp = hook(resp) - - return self.parse_response_token(resp.json()) - def token_from_fragment(self, authorization_response, state=None): token = parse_implicit_response(authorization_response, state) - return self.parse_response_token(token) + if "error" in token: + raise self.oauth_error_class( + error=token["error"], description=token.get("error_description") + ) + self.token = token + return token - def refresh_token(self, url, refresh_token=None, body='', - auth=None, headers=None, **kwargs): + def refresh_token( + self, url=None, refresh_token=None, body="", auth=None, headers=None, **kwargs + ): """Fetch a new access token using a refresh token. :param url: Refresh Token endpoint, must be HTTPS. @@ -242,47 +270,61 @@ def refresh_token(self, url, refresh_token=None, body='', :return: A :class:`OAuth2Token` object (a dict too). """ session_kwargs = self._extract_session_request_params(kwargs) - refresh_token = refresh_token or self.token.get('refresh_token') - if 'scope' not in kwargs and self.scope: - kwargs['scope'] = self.scope + refresh_token = refresh_token or self.token.get("refresh_token") + if "scope" not in kwargs and self.scope: + kwargs["scope"] = self.scope body = prepare_token_request( - 'refresh_token', body, - refresh_token=refresh_token, **kwargs + "refresh_token", body, refresh_token=refresh_token, **kwargs ) if headers is None: - headers = DEFAULT_HEADERS + headers = DEFAULT_HEADERS.copy() + + if url is None: + url = self.metadata.get("token_endpoint") - for hook in self.compliance_hook['refresh_token_request']: + for hook in self.compliance_hook["refresh_token_request"]: url, headers, body = hook(url, headers, body) if auth is None: auth = self.client_auth(self.token_endpoint_auth_method) return self._refresh_token( - url, refresh_token=refresh_token, body=body, headers=headers, - auth=auth, **session_kwargs) - - def _refresh_token(self, url, refresh_token=None, body='', headers=None, - auth=None, **kwargs): - resp = self.session.post( - url, data=dict(url_decode(body)), headers=headers, - auth=auth, **kwargs) - - for hook in self.compliance_hook['refresh_token_response']: - resp = hook(resp) - - token = self.parse_response_token(resp.json()) - if 'refresh_token' not in token: - self.token['refresh_token'] = refresh_token - - if callable(self.update_token): - self.update_token(self.token, refresh_token=refresh_token) - - return self.token + url, + refresh_token=refresh_token, + body=body, + headers=headers, + auth=auth, + **session_kwargs, + ) - def revoke_token(self, url, token=None, token_type_hint=None, - body=None, auth=None, headers=None, **kwargs): + def ensure_active_token(self, token=None): + if token is None: + token = self.token + if not token.is_expired(leeway=self.leeway): + return True + refresh_token = token.get("refresh_token") + url = self.metadata.get("token_endpoint") + if refresh_token and url: + self.refresh_token(url, refresh_token=refresh_token) + return True + elif self.metadata.get("grant_type") == "client_credentials": + access_token = token["access_token"] + new_token = self.fetch_token(url, grant_type="client_credentials") + if self.update_token: + self.update_token(new_token, access_token=access_token) + return True + + def revoke_token( + self, + url, + token=None, + token_type_hint=None, + body=None, + auth=None, + headers=None, + **kwargs, + ): """Revoke token method defined via `RFC7009`_. :param url: Revoke Token endpoint, must be HTTPS. @@ -297,13 +339,29 @@ def revoke_token(self, url, token=None, token_type_hint=None, .. _`RFC7009`: https://tools.ietf.org/html/rfc7009 """ + if auth is None: + auth = self.client_auth(self.revocation_endpoint_auth_method) return self._handle_token_hint( - 'revoke_token_request', url, - token=token, token_type_hint=token_type_hint, - body=body, auth=auth, headers=headers, **kwargs) + "revoke_token_request", + url, + token=token, + token_type_hint=token_type_hint, + body=body, + auth=auth, + headers=headers, + **kwargs, + ) - def introspect_token(self, url, token=None, token_type_hint=None, - body=None, auth=None, headers=None, **kwargs): + def introspect_token( + self, + url, + token=None, + token_type_hint=None, + body=None, + auth=None, + headers=None, + **kwargs, + ): """Implementation of OAuth 2.0 Token Introspection defined via `RFC7662`_. :param url: Introspection Endpoint, must be HTTPS. @@ -318,36 +376,18 @@ def introspect_token(self, url, token=None, token_type_hint=None, .. _`RFC7662`: https://tools.ietf.org/html/rfc7662 """ - return self._handle_token_hint( - 'introspect_token_request', url, - token=token, token_type_hint=token_type_hint, - body=body, auth=auth, headers=headers, **kwargs) - - def _handle_token_hint(self, hook, url, token=None, token_type_hint=None, - body=None, auth=None, headers=None, **kwargs): - if token is None and self.token: - token = self.token.get('refresh_token') or self.token.get('access_token') - - if body is None: - body = '' - - body, headers = prepare_revoke_token_request( - token, token_type_hint, body, headers) - - for hook in self.compliance_hook[hook]: - url, headers, body = hook(url, headers, body) - if auth is None: - auth = self.client_auth(self.revocation_endpoint_auth_method) - - session_kwargs = self._extract_session_request_params(kwargs) - return self._http_post( - url, body, auth=auth, headers=headers, **session_kwargs) - - def _http_post(self, url, body=None, auth=None, headers=None, **kwargs): - return self.session.post( - url, data=dict(url_decode(body)), - headers=headers, auth=auth, **kwargs) + auth = self.client_auth(self.token_endpoint_auth_method) + return self._handle_token_hint( + "introspect_token_request", + url, + token=token, + token_type_hint=token_type_hint, + body=body, + auth=auth, + headers=headers, + **kwargs, + ) def register_compliance_hook(self, hook_type, hook): """Register a hook for request/response tweaking. @@ -361,35 +401,104 @@ def register_compliance_hook(self, hook_type, hook): * revoke_token_request: invoked before revoking a token. * introspect_token_request: invoked before introspecting a token. """ - if hook_type == 'protected_request': + if hook_type == "protected_request": self.token_auth.hooks.add(hook) return if hook_type not in self.compliance_hook: - raise ValueError('Hook type %s is not in %s.', - hook_type, self.compliance_hook) + raise ValueError( + "Hook type %s is not in %s.", hook_type, self.compliance_hook + ) self.compliance_hook[hook_type].add(hook) - def parse_response_token(self, token): - if 'error' not in token: - self.token = token - return self.token + def parse_response_token(self, resp): + if resp.status_code >= 500: + resp.raise_for_status() - error = token['error'] - description = token.get('error_description', error) - self.handle_error(error, description) + token = resp.json() + if "error" in token: + raise self.oauth_error_class( + error=token["error"], description=token.get("error_description") + ) + self.token = token + return self.token - def _prepare_token_endpoint_body(self, body, grant_type, **kwargs): - if grant_type is None: - grant_type = _guess_grant_type(kwargs) + def _fetch_token( + self, url, body="", headers=None, auth=None, method="POST", **kwargs + ): + if method.upper() == "POST": + resp = self.session.post( + url, data=dict(url_decode(body)), headers=headers, auth=auth, **kwargs + ) + else: + if "?" in url: + url = "&".join([url, body]) + else: + url = "?".join([url, body]) + resp = self.session.request( + method, url, headers=headers, auth=auth, **kwargs + ) + + for hook in self.compliance_hook["access_token_response"]: + resp = hook(resp) + + return self.parse_response_token(resp) + + def _refresh_token( + self, url, refresh_token=None, body="", headers=None, auth=None, **kwargs + ): + resp = self._http_post(url, body=body, auth=auth, headers=headers, **kwargs) + + for hook in self.compliance_hook["refresh_token_response"]: + resp = hook(resp) - if grant_type == 'authorization_code': - if 'redirect_uri' not in kwargs: - kwargs['redirect_uri'] = self.redirect_uri + token = self.parse_response_token(resp) + if "refresh_token" not in token: + self.token["refresh_token"] = refresh_token + + if callable(self.update_token): + self.update_token(self.token, refresh_token=refresh_token) + + return self.token + + def _handle_token_hint( + self, + hook, + url, + token=None, + token_type_hint=None, + body=None, + auth=None, + headers=None, + **kwargs, + ): + if token is None and self.token: + token = self.token.get("refresh_token") or self.token.get("access_token") + + if body is None: + body = "" + + body, headers = prepare_revoke_token_request( + token, token_type_hint, body, headers + ) + + for compliance_hook in self.compliance_hook[hook]: + url, headers, body = compliance_hook(url, headers, body) + + if auth is None: + auth = self.client_auth(self.revocation_endpoint_auth_method) + + session_kwargs = self._extract_session_request_params(kwargs) + return self._http_post(url, body, auth=auth, headers=headers, **session_kwargs) + + def _prepare_token_endpoint_body(self, body, grant_type, **kwargs): + if grant_type == "authorization_code": + if "redirect_uri" not in kwargs: + kwargs["redirect_uri"] = self.redirect_uri return prepare_token_request(grant_type, body, **kwargs) - if 'scope' not in kwargs and self.scope: - kwargs['scope'] = self.scope + if "scope" not in kwargs and self.scope: + kwargs["scope"] = self.scope return prepare_token_request(grant_type, body, **kwargs) def _extract_session_request_params(self, kwargs): @@ -400,16 +509,20 @@ def _extract_session_request_params(self, kwargs): rv[k] = kwargs.pop(k) return rv - @staticmethod - def handle_error(error_type, error_description): - raise ValueError('{}: {}'.format(error_type, error_description)) + def _http_post(self, url, body=None, auth=None, headers=None, **kwargs): + return self.session.post( + url, data=dict(url_decode(body)), headers=headers, auth=auth, **kwargs + ) + + def __del__(self): + del self.session def _guess_grant_type(kwargs): - if 'code' in kwargs: - grant_type = 'authorization_code' - elif 'username' in kwargs and 'password' in kwargs: - grant_type = 'password' + if "code" in kwargs: + grant_type = "authorization_code" + elif "username" in kwargs and "password" in kwargs: + grant_type = "password" else: - grant_type = 'client_credentials' + grant_type = "client_credentials" return grant_type diff --git a/authlib/oauth2/rfc6749/__init__.py b/authlib/oauth2/rfc6749/__init__.py index 2994f4f4e..7acd4fabb 100644 --- a/authlib/oauth2/rfc6749/__init__.py +++ b/authlib/oauth2/rfc6749/__init__.py @@ -1,77 +1,94 @@ -# -*- coding: utf-8 -*- -""" - authlib.oauth2.rfc6749 - ~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth2.rfc6749. +~~~~~~~~~~~~~~~~~~~~~~ - This module represents a direct implementation of - The OAuth 2.0 Authorization Framework. +This module represents a direct implementation of +The OAuth 2.0 Authorization Framework. - https://tools.ietf.org/html/rfc6749 +https://tools.ietf.org/html/rfc6749 """ -from .wrappers import OAuth2Request, OAuth2Token, HttpRequest -from .errors import ( - OAuth2Error, - AccessDeniedError, - MissingAuthorizationError, - InvalidGrantError, - InvalidClientError, - InvalidRequestError, - InvalidScopeError, - InsecureTransportError, - UnauthorizedClientError, - UnsupportedGrantTypeError, - UnsupportedTokenTypeError, - # exceptions for clients - MissingCodeException, - MissingTokenException, - MissingTokenTypeException, - MismatchingStateException, -) -from .models import ClientMixin, AuthorizationCodeMixin, TokenMixin from .authenticate_client import ClientAuthentication from .authorization_server import AuthorizationServer +from .endpoint import Endpoint +from .endpoint import EndpointRequest +from .errors import AccessDeniedError +from .errors import InsecureTransportError +from .errors import InvalidClientError +from .errors import InvalidGrantError +from .errors import InvalidRequestError +from .errors import InvalidScopeError +from .errors import MismatchingStateException +from .errors import MissingAuthorizationError +from .errors import MissingCodeException # exceptions for clients +from .errors import MissingTokenException +from .errors import MissingTokenTypeException +from .errors import OAuth2Error +from .errors import UnauthorizedClientError +from .errors import UnsupportedGrantTypeError +from .errors import UnsupportedResponseTypeError +from .errors import UnsupportedTokenTypeError +from .grants import AuthorizationCodeGrant +from .grants import AuthorizationEndpointMixin +from .grants import BaseGrant +from .grants import ClientCredentialsGrant +from .grants import ImplicitGrant +from .grants import RefreshTokenGrant +from .grants import ResourceOwnerPasswordCredentialsGrant +from .grants import TokenEndpointMixin +from .models import AuthorizationCodeMixin +from .models import ClientMixin +from .models import TokenMixin +from .requests import JsonPayload +from .requests import JsonRequest +from .requests import OAuth2Payload +from .requests import OAuth2Request from .resource_protector import ResourceProtector +from .resource_protector import TokenValidator from .token_endpoint import TokenEndpoint -from .grants import ( - BaseGrant, - AuthorizationEndpointMixin, - TokenEndpointMixin, - AuthorizationCodeGrant, - ImplicitGrant, - ResourceOwnerPasswordCredentialsGrant, - ClientCredentialsGrant, - RefreshTokenGrant, -) +from .util import list_to_scope +from .util import scope_to_list +from .wrappers import OAuth2Token __all__ = [ - 'OAuth2Request', 'OAuth2Token', 'HttpRequest', - 'OAuth2Error', - 'AccessDeniedError', - 'MissingAuthorizationError', - 'InvalidGrantError', - 'InvalidClientError', - 'InvalidRequestError', - 'InvalidScopeError', - 'InsecureTransportError', - 'UnauthorizedClientError', - 'UnsupportedGrantTypeError', - 'UnsupportedTokenTypeError', - 'MissingCodeException', - 'MissingTokenException', - 'MissingTokenTypeException', - 'MismatchingStateException', - 'ClientMixin', 'AuthorizationCodeMixin', 'TokenMixin', - 'ClientAuthentication', - 'AuthorizationServer', - 'ResourceProtector', - 'TokenEndpoint', - 'BaseGrant', - 'AuthorizationEndpointMixin', - 'TokenEndpointMixin', - 'AuthorizationCodeGrant', - 'ImplicitGrant', - 'ResourceOwnerPasswordCredentialsGrant', - 'ClientCredentialsGrant', - 'RefreshTokenGrant', + "OAuth2Payload", + "OAuth2Token", + "OAuth2Request", + "JsonPayload", + "JsonRequest", + "OAuth2Error", + "AccessDeniedError", + "MissingAuthorizationError", + "InvalidGrantError", + "InvalidClientError", + "InvalidRequestError", + "InvalidScopeError", + "InsecureTransportError", + "UnauthorizedClientError", + "UnsupportedResponseTypeError", + "UnsupportedGrantTypeError", + "UnsupportedTokenTypeError", + "MissingCodeException", + "MissingTokenException", + "MissingTokenTypeException", + "MismatchingStateException", + "ClientMixin", + "AuthorizationCodeMixin", + "TokenMixin", + "ClientAuthentication", + "AuthorizationServer", + "ResourceProtector", + "TokenValidator", + "Endpoint", + "EndpointRequest", + "TokenEndpoint", + "BaseGrant", + "AuthorizationEndpointMixin", + "TokenEndpointMixin", + "AuthorizationCodeGrant", + "ImplicitGrant", + "ResourceOwnerPasswordCredentialsGrant", + "ClientCredentialsGrant", + "RefreshTokenGrant", + "scope_to_list", + "list_to_scope", ] diff --git a/authlib/oauth2/rfc6749/authenticate_client.py b/authlib/oauth2/rfc6749/authenticate_client.py index d21289a1f..3792dcabf 100644 --- a/authlib/oauth2/rfc6749/authenticate_client.py +++ b/authlib/oauth2/rfc6749/authenticate_client.py @@ -1,55 +1,60 @@ -""" - authlib.oauth2.rfc6749.authenticate_client - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth2.rfc6749.authenticate_client. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - Registry of client authentication methods, with 3 built-in methods: +Registry of client authentication methods, with 3 built-in methods: - 1. client_secret_basic - 2. client_secret_post - 3. none +1. client_secret_basic +2. client_secret_post +3. none - The "client_secret_basic" method is used a lot in examples of `RFC6749`_, - but the concept of naming are introduced in `RFC7591`_. +The "client_secret_basic" method is used a lot in examples of `RFC6749`_, +but the concept of naming are introduced in `RFC7591`_. - .. _`RFC6749`: https://tools.ietf.org/html/rfc6749 - .. _`RFC7591`: https://tools.ietf.org/html/rfc7591 +.. _`RFC6749`: https://tools.ietf.org/html/rfc6749 +.. _`RFC7591`: https://tools.ietf.org/html/rfc7591 """ import logging + from .errors import InvalidClientError from .util import extract_basic_authorization log = logging.getLogger(__name__) -__all__ = ['ClientAuthentication'] +__all__ = ["ClientAuthentication"] -class ClientAuthentication(object): +class ClientAuthentication: def __init__(self, query_client): self.query_client = query_client self._methods = { - 'none': authenticate_none, - 'client_secret_basic': authenticate_client_secret_basic, - 'client_secret_post': authenticate_client_secret_post, + "none": authenticate_none, + "client_secret_basic": authenticate_client_secret_basic, + "client_secret_post": authenticate_client_secret_post, } def register(self, method, func): self._methods[method] = func - def authenticate(self, request, methods): + def authenticate(self, request, methods, endpoint): for method in methods: func = self._methods[method] client = func(self.query_client, request) - if client: + if client and client.check_endpoint_auth_method(method, endpoint): request.auth_method = method return client - if 'client_secret_basic' in methods: - raise InvalidClientError(state=request.state, status_code=401) - raise InvalidClientError(state=request.state) + if "client_secret_basic" in methods: + raise InvalidClientError( + status_code=401, + description=f"The client cannot authenticate with methods: {methods}", + ) + raise InvalidClientError( + description=f"The client cannot authenticate with methods: {methods}", + ) - def __call__(self, request, methods): - return self.authenticate(request, methods) + def __call__(self, request, methods, endpoint="token"): + return self.authenticate(request, methods, endpoint) def authenticate_client_secret_basic(query_client, request): @@ -58,18 +63,11 @@ def authenticate_client_secret_basic(query_client, request): """ client_id, client_secret = extract_basic_authorization(request.headers) if client_id and client_secret: - client = _validate_client(query_client, client_id, request.state, 401) - if client.check_token_endpoint_auth_method('client_secret_basic') \ - and client.check_client_secret(client_secret): - log.debug( - 'Authenticate %s via "client_secret_basic" ' - 'success', client_id - ) + client = _validate_client(query_client, client_id, 401) + if client.check_client_secret(client_secret): + log.debug(f'Authenticate {client_id} via "client_secret_basic" success') return client - log.debug( - 'Authenticate %s via "client_secret_basic" ' - 'failed', client_id - ) + log.debug(f'Authenticate {client_id} via "client_secret_basic" failed') def authenticate_client_secret_post(query_client, request): @@ -77,48 +75,40 @@ def authenticate_client_secret_post(query_client, request): uses POST parameters for authentication. """ data = request.form - client_id = data.get('client_id') - client_secret = data.get('client_secret') + client_id = data.get("client_id") + client_secret = data.get("client_secret") if client_id and client_secret: - client = _validate_client(query_client, client_id, request.state) - if client.check_token_endpoint_auth_method('client_secret_post') \ - and client.check_client_secret(client_secret): - log.debug( - 'Authenticate %s via "client_secret_post" ' - 'success', client_id - ) + client = _validate_client(query_client, client_id) + if client.check_client_secret(client_secret): + log.debug(f'Authenticate {client_id} via "client_secret_post" success') return client - log.debug( - 'Authenticate %s via "client_secret_post" ' - 'failed', client_id - ) + log.debug(f'Authenticate {client_id} via "client_secret_post" failed') def authenticate_none(query_client, request): """Authenticate public client by ``none`` method. The client does not have a client secret. """ - client_id = request.client_id - if client_id and 'client_secret' not in request.data: - client = _validate_client(query_client, client_id, request.state) - if client.check_token_endpoint_auth_method('none'): - log.debug( - 'Authenticate %s via "none" ' - 'success', client_id - ) - return client - log.debug( - 'Authenticate {} via "none" ' - 'failed'.format(client_id) - ) + client_id = request.payload.client_id + if client_id and not request.payload.data.get("client_secret"): + client = _validate_client(query_client, client_id) + log.debug(f'Authenticate {client_id} via "none" success') + return client + log.debug(f'Authenticate {client_id} via "none" failed') -def _validate_client(query_client, client_id, state=None, status_code=400): +def _validate_client(query_client, client_id, status_code=400): if client_id is None: - raise InvalidClientError(state=state, status_code=status_code) + raise InvalidClientError( + status_code=status_code, + description="Missing 'client_id' parameter.", + ) client = query_client(client_id) if not client: - raise InvalidClientError(state=state, status_code=status_code) + raise InvalidClientError( + status_code=status_code, + description="The client does not exist on this server.", + ) return client diff --git a/authlib/oauth2/rfc6749/authorization_server.py b/authlib/oauth2/rfc6749/authorization_server.py index c1f3ddca1..fe64089bb 100644 --- a/authlib/oauth2/rfc6749/authorization_server.py +++ b/authlib/oauth2/rfc6749/authorization_server.py @@ -1,42 +1,128 @@ +from authlib.common.errors import ContinueIteration +from authlib.deprecate import deprecate + from .authenticate_client import ClientAuthentication -from .errors import ( - OAuth2Error, - InvalidGrantError, - InvalidScopeError, - UnsupportedGrantTypeError, -) +from .endpoint import Endpoint +from .endpoint import EndpointRequest +from .errors import InvalidScopeError +from .errors import OAuth2Error +from .errors import UnsupportedGrantTypeError +from .errors import UnsupportedResponseTypeError +from .hooks import Hookable +from .hooks import hooked +from .requests import JsonRequest +from .requests import OAuth2Request from .util import scope_to_list -class AuthorizationServer(object): +class AuthorizationServer(Hookable): """Authorization server that handles Authorization Endpoint and Token Endpoint. - :param query_client: A function to get client by client_id. The client - model class MUST implement the methods described by - :class:`~authlib.oauth2.rfc6749.ClientMixin`. - :param save_token: A method to save tokens. - :param generate_token: A method to generate tokens. - :param metadata: A dict of Authorization Server Metadata + :param scopes_supported: A list of supported scopes by this authorization server. """ - def __init__(self, query_client, save_token, generate_token=None, metadata=None): - self.query_client = query_client - self.save_token = save_token - self.generate_token = generate_token - self.metadata = metadata + def __init__(self, scopes_supported=None): + super().__init__() + self.scopes_supported = scopes_supported + self._token_generators = {} self._client_auth = None self._authorization_grants = [] self._token_grants = [] self._endpoints = {} + self._extensions = [] - def authenticate_client(self, request, methods): + def query_client(self, client_id): + """Query OAuth client by client_id. The client model class MUST + implement the methods described by + :class:`~authlib.oauth2.rfc6749.ClientMixin`. + """ + raise NotImplementedError() + + def save_token(self, token, request): + """Define function to save the generated token into database.""" + raise NotImplementedError() + + def generate_token( + self, + grant_type, + client, + user=None, + scope=None, + expires_in=None, + include_refresh_token=True, + ): + """Generate the token dict. + + :param grant_type: current requested grant_type. + :param client: the client that making the request. + :param user: current authorized user. + :param expires_in: if provided, use this value as expires_in. + :param scope: current requested scope. + :param include_refresh_token: should refresh_token be included. + :return: Token dict + """ + # generator for a specified grant type + func = self._token_generators.get(grant_type) + if not func: + # default generator for all grant types + func = self._token_generators.get("default") + if not func: + raise RuntimeError("No configured token generator") + + return func( + grant_type=grant_type, + client=client, + user=user, + scope=scope, + expires_in=expires_in, + include_refresh_token=include_refresh_token, + ) + + def register_token_generator(self, grant_type, func): + """Register a function as token generator for the given ``grant_type``. + Developers MUST register a default token generator with a special + ``grant_type=default``:: + + def generate_bearer_token( + grant_type, + client, + user=None, + scope=None, + expires_in=None, + include_refresh_token=True, + ): + token = {"token_type": "Bearer", "access_token": ...} + if include_refresh_token: + token["refresh_token"] = ... + ... + return token + + + authorization_server.register_token_generator( + "default", generate_bearer_token + ) + + If you register a generator for a certain grant type, that generator will only works + for the given grant type:: + + authorization_server.register_token_generator( + "client_credentials", + generate_bearer_token, + ) + + :param grant_type: string name of the grant type + :param func: a function to generate token + """ + self._token_generators[grant_type] = func + + def authenticate_client(self, request, methods, endpoint="token"): """Authenticate client via HTTP request information with the given methods, such as ``client_secret_basic``, ``client_secret_post``. """ if self._client_auth is None and self.query_client: self._client_auth = ClientAuthentication(self.query_client) - return self._client_auth(request, methods) + return self._client_auth(request, methods, endpoint) def register_client_auth_method(self, method, func): """Add more client auth method. The default methods are: @@ -52,38 +138,35 @@ def register_client_auth_method(self, method, func): an example for this method:: def authenticate_client_via_custom(query_client, request): - client_id = request.headers['X-Client-Id'] + client_id = request.headers["X-Client-Id"] client = query_client(client_id) do_some_validation(client) return client + authorization_server.register_client_auth_method( - 'custom', authenticate_client_via_custom) + "custom", authenticate_client_via_custom + ) """ if self._client_auth is None and self.query_client: self._client_auth = ClientAuthentication(self.query_client) self._client_auth.register(method, func) - def get_translations(self, request): - """Return a translations instance used for i18n error messages. - Framework SHOULD implement this function. - """ - return None + def register_extension(self, extension): + self._extensions.append(extension(self)) - def get_error_uris(self, request): - """Return a dict of error uris mapping. Framework SHOULD implement - this function. - """ + def get_error_uri(self, request, error): + """Return a URI for the given error, framework may implement this method.""" return None def send_signal(self, name, *args, **kwargs): """Framework integration can re-implement this method to support signal system. """ - pass + raise NotImplementedError() - def create_oauth2_request(self, request): + def create_oauth2_request(self, request) -> OAuth2Request: """This method MUST be implemented in framework integrations. It is used to create an OAuth2Request instance. @@ -92,7 +175,7 @@ def create_oauth2_request(self, request): """ raise NotImplementedError() - def create_json_request(self, request): + def create_json_request(self, request) -> JsonRequest: """This method MUST be implemented in framework integrations. It is used to create an HttpRequest instance. @@ -105,15 +188,14 @@ def handle_response(self, status, body, headers): """Return HTTP response. Framework MUST implement this function.""" raise NotImplementedError() - def validate_requested_scope(self, scope, state=None): + def validate_requested_scope(self, scope): """Validate if requested scope is supported by Authorization Server. Developers CAN re-write this method to meet your needs. """ - if scope and self.metadata: - scopes_supported = self.metadata.get('scopes_supported') + if scope and self.scopes_supported: scopes = set(scope_to_list(scope)) - if scopes_supported and not set(scopes_supported).issuperset(scopes): - raise InvalidScopeError(state=state) + if not set(self.scopes_supported).issuperset(scopes): + raise InvalidScopeError() def register_grant(self, grant_cls, extensions=None): """Register a grant class into the endpoint registry. Developers @@ -129,32 +211,75 @@ def authenticate_user(self, credential): :param grant_cls: a grant class. :param extensions: extensions for the grant class. """ - if hasattr(grant_cls, 'check_authorization_endpoint'): + if hasattr(grant_cls, "check_authorization_endpoint"): self._authorization_grants.append((grant_cls, extensions)) - if hasattr(grant_cls, 'check_token_endpoint'): + if hasattr(grant_cls, "check_token_endpoint"): self._token_grants.append((grant_cls, extensions)) - def register_endpoint(self, endpoint_cls): + def register_endpoint(self, endpoint: type[Endpoint] | Endpoint): """Add extra endpoint to authorization server. e.g. RevocationEndpoint:: authorization_server.register_endpoint(RevocationEndpoint) - :param endpoint_cls: A endpoint class + :param endpoint: An endpoint class or instance. """ - self._endpoints[endpoint_cls.ENDPOINT_NAME] = endpoint_cls(self) + if isinstance(endpoint, type): + endpoint = endpoint(self) + else: + endpoint.server = self + + endpoints = self._endpoints.setdefault(endpoint.ENDPOINT_NAME, []) + endpoints.append(endpoint) + @hooked def get_authorization_grant(self, request): """Find the authorization grant for current request. :param request: OAuth2Request instance. :return: grant instance """ - for (grant_cls, extensions) in self._authorization_grants: + for grant_cls, extensions in self._authorization_grants: if grant_cls.check_authorization_endpoint(request): return _create_grant(grant_cls, extensions, request, self) - raise InvalidGrantError( - 'Response type {!r} is not supported'.format(request.response_type)) + + # Per RFC 6749 §4.1.2.1, only redirect with the error if the client + # exists and the redirect_uri has been validated against it. + redirect_uri = None + if client_id := request.payload.client_id: + if client := self.query_client(client_id): + if requested_uri := request.payload.redirect_uri: + if client.check_redirect_uri(requested_uri): + redirect_uri = requested_uri + else: + redirect_uri = client.get_default_redirect_uri() + + raise UnsupportedResponseTypeError( + f"The response type '{request.payload.response_type}' is not supported by the server.", + request.payload.response_type, + redirect_uri=redirect_uri, + ) + + def get_consent_grant(self, request=None, end_user=None): + """Validate current HTTP request for authorization page. This page + is designed for resource owner to grant or deny the authorization. + """ + request = self.create_oauth2_request(request) + + try: + request.user = end_user + + grant = self.get_authorization_grant(request) + grant.validate_no_multiple_request_parameter(request) + grant.validate_consent_request() + + except OAuth2Error as error: + # REQUIRED if a "state" parameter was present in the client + # authorization request. The exact value received from the + # client. + error.state = request.payload.state + raise + return grant def get_token_grant(self, request): """Find the token grant for current request. @@ -162,31 +287,79 @@ def get_token_grant(self, request): :param request: OAuth2Request instance. :return: grant instance """ - for (grant_cls, extensions) in self._token_grants: - if grant_cls.check_token_endpoint(request) and \ - request.method in grant_cls.TOKEN_ENDPOINT_HTTP_METHODS: + for grant_cls, extensions in self._token_grants: + if grant_cls.check_token_endpoint(request): return _create_grant(grant_cls, extensions, request, self) - raise UnsupportedGrantTypeError( - 'Grant type {!r} is not supported'.format(request.grant_type)) + raise UnsupportedGrantTypeError(request.payload.grant_type) - def create_endpoint_response(self, name, request=None): - """Validate endpoint request and create endpoint response. + def validate_endpoint_request(self, name, request=None) -> EndpointRequest: + """Validate endpoint request and return the validated request object. + + Use this for interactive endpoints where you need to handle UI + between validation and response creation. :param name: Endpoint name - :param request: HTTP request instance. - :return: Response + :param request: HTTP request instance + :returns: Validated EndpointRequest object + :raises OAuth2Error: If validation fails + :raises RuntimeError: If endpoint not found + + Example:: + + req = server.validate_endpoint_request("end_session") + if req.needs_confirmation: + return render_template("confirm_logout.html", ...) + return server.create_endpoint_response("end_session", req) """ if name not in self._endpoints: - raise RuntimeError('There is no "{}" endpoint.'.format(name)) + raise RuntimeError(f"There is no '{name}' endpoint.") - endpoint = self._endpoints[name] + endpoint = self._endpoints[name][0] request = endpoint.create_endpoint_request(request) - try: - return self.handle_response(*endpoint(request)) - except OAuth2Error as error: - return self.handle_error_response(request, error) + return endpoint.validate_request(request) + + def create_endpoint_response(self, name, request=None): + """Validate endpoint request and create endpoint response. - def create_authorization_response(self, request=None, grant_user=None): + Can be called with: + - A raw HTTP request or None: validates and responds in one step + - A validated EndpointRequest: skips validation, creates response directly + + :param name: Endpoint name + :param request: HTTP request instance or validated EndpointRequest + :return: Response, or None if the endpoint returns None + """ + if name not in self._endpoints: + raise RuntimeError(f"There is no '{name}' endpoint.") + + endpoints = self._endpoints[name] + + # If request is already validated, create response directly + if isinstance(request, EndpointRequest): + endpoint = endpoints[0] + try: + result = endpoint.create_response(request) + if result is None: + return None + return self.handle_response(*result) + except OAuth2Error as error: + return self.handle_error_response(request.request, error) + + # Otherwise, validate and respond (existing behavior) + for endpoint in endpoints: + request = endpoint.create_endpoint_request(request) + try: + result = endpoint(request) + if result is None: + return None + return self.handle_response(*result) + except ContinueIteration: + continue + except OAuth2Error as error: + return self.handle_error_response(request, error) + + @hooked + def create_authorization_response(self, request=None, grant_user=None, grant=None): """Validate authorization request and create authorization response. :param request: HTTP request instance. @@ -194,18 +367,27 @@ def create_authorization_response(self, request=None, grant_user=None): it is None. :returns: Response """ - request = self.create_oauth2_request(request) - try: - grant = self.get_authorization_grant(request) - except InvalidGrantError as error: - return self.handle_error_response(request, error) + if not isinstance(request, OAuth2Request): + request = self.create_oauth2_request(request) + + if not grant: + deprecate("The 'grant' parameter will become mandatory.", version="1.8") + try: + grant = self.get_authorization_grant(request) + except UnsupportedResponseTypeError as error: + error.state = request.payload.state + return self.handle_error_response(request, error) try: redirect_uri = grant.validate_authorization_request() args = grant.create_authorization_response(redirect_uri, grant_user) - return self.handle_response(*args) + response = self.handle_response(*args) except OAuth2Error as error: - return self.handle_error_response(request, error) + error.state = request.payload.state + response = self.handle_error_response(request, error) + + grant.execute_hook("after_authorization_response", response) + return response def create_token_response(self, request=None): """Validate token request and create token response. @@ -226,10 +408,7 @@ def create_token_response(self, request=None): return self.handle_error_response(request, error) def handle_error_response(self, request, error): - return self.handle_response(*error( - translations=self.get_translations(request), - error_uris=self.get_error_uris(request) - )) + return self.handle_response(*error(self.get_error_uri(request, error))) def _create_grant(grant_cls, extensions, request, server): diff --git a/authlib/oauth2/rfc6749/endpoint.py b/authlib/oauth2/rfc6749/endpoint.py new file mode 100644 index 000000000..918294061 --- /dev/null +++ b/authlib/oauth2/rfc6749/endpoint.py @@ -0,0 +1,90 @@ +""" +authlib.oauth2.rfc6749.endpoint +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Base class for OAuth2 endpoints. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING +from typing import Any + +if TYPE_CHECKING: + from .requests import OAuth2Request + + +@dataclass +class EndpointRequest: + """Base class for validated endpoint requests. + + This object is returned by :meth:`Endpoint.validate_request` and contains + all validated information from the endpoint request. Subclasses add + endpoint-specific fields. + """ + + request: OAuth2Request + client: Any = None + + +class Endpoint: + """Base class for OAuth2 endpoints. + + Supports two modes of operation: + + **Automatic mode** (non-interactive endpoints): + Call ``server.create_endpoint_response(name)`` which validates the request + and creates the response in one step. + + **Interactive mode** (endpoints requiring user confirmation): + 1. Call ``server.validate_endpoint_request(name)`` to get a validated request + 2. Handle user interaction (e.g., show confirmation page) + 3. Call ``server.create_endpoint_response(name, validated_request)`` to complete + + Subclasses must implement :meth:`validate_request` and :meth:`create_response`. + """ + + #: Endpoint name used for registration + ENDPOINT_NAME: str | None = None + + def __init__(self, server=None): + self.server = server + + def create_endpoint_request(self, request): + """Convert framework request to OAuth2Request.""" + return self.server.create_oauth2_request(request) + + def validate_request(self, request: OAuth2Request) -> EndpointRequest: + """Validate the request and return a validated request object. + + :param request: The OAuth2Request to validate + :returns: EndpointRequest with validated data + :raises OAuth2Error: If validation fails + """ + raise NotImplementedError() + + def create_response( + self, validated_request: EndpointRequest + ) -> tuple[int, Any, list] | None: + """Create the HTTP response from a validated request. + + :param validated_request: The validated EndpointRequest + :returns: Tuple of (status_code, body, headers), or None if the + application should provide its own response + """ + raise NotImplementedError() + + def create_endpoint_response( + self, request: OAuth2Request + ) -> tuple[int, Any, list] | None: + """Validate and respond in one step (non-interactive mode). + + :param request: The OAuth2Request to process + :returns: Tuple of (status_code, body, headers), or None + """ + validated = self.validate_request(request) + return self.create_response(validated) + + def __call__(self, request: OAuth2Request) -> tuple[int, Any, list] | None: + return self.create_endpoint_response(request) diff --git a/authlib/oauth2/rfc6749/errors.py b/authlib/oauth2/rfc6749/errors.py index c2612aa61..87d73b3ab 100644 --- a/authlib/oauth2/rfc6749/errors.py +++ b/authlib/oauth2/rfc6749/errors.py @@ -1,55 +1,62 @@ -""" - authlib.oauth2.rfc6749.errors - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth2.rfc6749.errors. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - Implementation for OAuth 2 Error Response. A basic error has - parameters: +Implementation for OAuth 2 Error Response. A basic error has +parameters: - error - REQUIRED. A single ASCII [USASCII] error code. +Error: +REQUIRED. A single ASCII [USASCII] error code. - error_description - OPTIONAL. Human-readable ASCII [USASCII] text providing - additional information, used to assist the client developer in - understanding the error that occurred. +error_description +OPTIONAL. Human-readable ASCII [USASCII] text providing +additional information, used to assist the client developer in +understanding the error that occurred. - error_uri - OPTIONAL. A URI identifying a human-readable web page with - information about the error, used to provide the client - developer with additional information about the error. - Values for the "error_uri" parameter MUST conform to the - URI-reference syntax and thus MUST NOT include characters - outside the set %x21 / %x23-5B / %x5D-7E. +error_uri +OPTIONAL. A URI identifying a human-readable web page with +information about the error, used to provide the client +developer with additional information about the error. +Values for the "error_uri" parameter MUST conform to the +URI-reference syntax and thus MUST NOT include characters +outside the set %x21 / %x23-5B / %x5D-7E. - state - REQUIRED if a "state" parameter was present in the client - authorization request. The exact value received from the - client. +state +REQUIRED if a "state" parameter was present in the client +authorization request. The exact value received from the +client. - https://tools.ietf.org/html/rfc6749#section-5.2 +https://tools.ietf.org/html/rfc6749#section-5.2 + +:copyright: (c) 2017 by Hsiaoming Yang. - :copyright: (c) 2017 by Hsiaoming Yang. """ -from authlib.oauth2.base import OAuth2Error + from authlib.common.security import is_secure_transport +from authlib.oauth2.base import OAuth2Error __all__ = [ - 'OAuth2Error', - 'InsecureTransportError', 'InvalidRequestError', - 'InvalidClientError', 'InvalidGrantError', - 'UnauthorizedClientError', 'UnsupportedGrantTypeError', - 'InvalidScopeError', 'AccessDeniedError', - 'MissingAuthorizationError', 'UnsupportedTokenTypeError', - 'MissingCodeException', 'MissingTokenException', - 'MissingTokenTypeException', 'MismatchingStateException', + "OAuth2Error", + "InsecureTransportError", + "InvalidRequestError", + "InvalidClientError", + "UnauthorizedClientError", + "InvalidGrantError", + "UnsupportedResponseTypeError", + "UnsupportedGrantTypeError", + "InvalidScopeError", + "AccessDeniedError", + "MissingAuthorizationError", + "UnsupportedTokenTypeError", + "MissingCodeException", + "MissingTokenException", + "MissingTokenTypeException", + "MismatchingStateException", ] class InsecureTransportError(OAuth2Error): - error = 'insecure_transport' - - def get_error_description(self): - return self.gettext('OAuth 2 MUST utilize https.') + error = "insecure_transport" + description = "OAuth 2 MUST utilize https." @classmethod def check(cls, uri): @@ -67,7 +74,8 @@ class InvalidRequestError(OAuth2Error): https://tools.ietf.org/html/rfc6749#section-5.2 """ - error = 'invalid_request' + + error = "invalid_request" class InvalidClientError(OAuth2Error): @@ -84,22 +92,21 @@ class InvalidClientError(OAuth2Error): https://tools.ietf.org/html/rfc6749#section-5.2 """ - error = 'invalid_client' + + error = "invalid_client" status_code = 400 def get_headers(self): - headers = super(InvalidClientError, self).get_headers() + headers = super().get_headers() if self.status_code == 401: error_description = self.get_error_description() # safe escape - error_description = error_description.replace('"', '|') + error_description = error_description.replace('"', "|") extras = [ - 'error="{}"'.format(self.error), - 'error_description="{}"'.format(error_description) + f'error="{self.error}"', + f'error_description="{error_description}"', ] - headers.append( - ('WWW-Authenticate', 'Basic ' + ', '.join(extras)) - ) + headers.append(("WWW-Authenticate", "Basic " + ", ".join(extras))) return headers @@ -112,16 +119,33 @@ class InvalidGrantError(OAuth2Error): https://tools.ietf.org/html/rfc6749#section-5.2 """ - error = 'invalid_grant' + + error = "invalid_grant" class UnauthorizedClientError(OAuth2Error): - """ The authenticated client is not authorized to use this + """The authenticated client is not authorized to use this authorization grant type. https://tools.ietf.org/html/rfc6749#section-5.2 """ - error = 'unauthorized_client' + + error = "unauthorized_client" + + +class UnsupportedResponseTypeError(OAuth2Error): + """The authorization server does not support obtaining + an access token using this method. + """ + + error = "unsupported_response_type" + + def __init__(self, response_type, *args, **kwargs): + super().__init__(*args, **kwargs) + self.response_type = response_type + + def get_error_description(self): + return f"response_type={self.response_type} is not supported" class UnsupportedGrantTypeError(OAuth2Error): @@ -130,7 +154,15 @@ class UnsupportedGrantTypeError(OAuth2Error): https://tools.ietf.org/html/rfc6749#section-5.2 """ - error = 'unsupported_grant_type' + + error = "unsupported_grant_type" + + def __init__(self, grant_type): + super().__init__() + self.grant_type = grant_type + + def get_error_description(self): + return f"grant_type={self.grant_type} is not supported" class InvalidScopeError(OAuth2Error): @@ -139,11 +171,9 @@ class InvalidScopeError(OAuth2Error): https://tools.ietf.org/html/rfc6749#section-5.2 """ - error = 'invalid_scope' - def get_error_description(self): - return self.gettext( - 'The requested scope is invalid, unknown, or malformed.') + error = "invalid_scope" + description = "The requested scope is invalid, unknown, or malformed." class AccessDeniedError(OAuth2Error): @@ -154,47 +184,64 @@ class AccessDeniedError(OAuth2Error): .. _`Section 4.1.2.1`: https://tools.ietf.org/html/rfc6749#section-4.1.2.1 """ - error = 'access_denied' - def get_error_description(self): - return self.gettext( - 'The resource owner or authorization server denied the request') + error = "access_denied" + description = "The resource owner or authorization server denied the request" # -- below are extended errors -- # -class MissingAuthorizationError(OAuth2Error): - error = 'missing_authorization' +class ForbiddenError(OAuth2Error): status_code = 401 - def get_error_description(self): - return self.gettext('Missing "Authorization" in headers.') + def __init__(self, auth_type=None, realm=None): + super().__init__() + self.auth_type = auth_type + self.realm = realm + + def get_headers(self): + headers = super().get_headers() + if not self.auth_type: + return headers + + extras = [] + if self.realm: + extras.append(f'realm="{self.realm}"') + extras.append(f'error="{self.error}"') + error_description = self.description + extras.append(f'error_description="{error_description}"') + headers.append(("WWW-Authenticate", f"{self.auth_type} " + ", ".join(extras))) + return headers -class UnsupportedTokenTypeError(OAuth2Error): - error = 'unsupported_token_type' - status_code = 401 +class MissingAuthorizationError(ForbiddenError): + error = "missing_authorization" + description = "Missing 'Authorization' in headers." + + +class UnsupportedTokenTypeError(ForbiddenError): + error = "unsupported_token_type" # -- exceptions for clients -- # class MissingCodeException(OAuth2Error): - error = 'missing_code' - description = 'Missing "code" in response.' + error = "missing_code" + description = "Missing 'code' in response." class MissingTokenException(OAuth2Error): - error = 'missing_token' - description = 'Missing "access_token" in response.' + error = "missing_token" + description = "Missing 'access_token' in response." class MissingTokenTypeException(OAuth2Error): - error = 'missing_token_type' - description = 'Missing "token_type" in response.' + error = "missing_token_type" + description = "Missing 'token_type' in response." class MismatchingStateException(OAuth2Error): - error = 'mismatching_state' - description = 'CSRF Warning! State not equal in request and response.' + error = "mismatching_state" + description = "CSRF Warning! State not equal in request and response." diff --git a/authlib/oauth2/rfc6749/grants/__init__.py b/authlib/oauth2/rfc6749/grants/__init__.py index b1797565f..f627c4189 100644 --- a/authlib/oauth2/rfc6749/grants/__init__.py +++ b/authlib/oauth2/rfc6749/grants/__init__.py @@ -1,37 +1,41 @@ """ - authlib.oauth2.rfc6749.grants - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +authlib.oauth2.rfc6749.grants +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - Implementation for `Section 4`_ of "Obtaining Authorization". +Implementation for `Section 4`_ of "Obtaining Authorization". - To request an access token, the client obtains authorization from the - resource owner. The authorization is expressed in the form of an - authorization grant, which the client uses to request the access - token. OAuth defines four grant types: +To request an access token, the client obtains authorization from the +resource owner. The authorization is expressed in the form of an +authorization grant, which the client uses to request the access +token. OAuth defines four grant types: - 1. authorization code - 2. implicit - 3. resource owner password credentials - 4. client credentials. +1. authorization code +2. implicit +3. resource owner password credentials +4. client credentials. - It also provides an extension mechanism for defining additional grant - types. Authlib defines refresh_token as a grant type too. +It also provides an extension mechanism for defining additional grant +types. Authlib defines refresh_token as a grant type too. - .. _`Section 4`: https://tools.ietf.org/html/rfc6749#section-4 +.. _`Section 4`: https://tools.ietf.org/html/rfc6749#section-4 """ -# flake8: noqa - -from .base import BaseGrant, AuthorizationEndpointMixin, TokenEndpointMixin from .authorization_code import AuthorizationCodeGrant -from .implicit import ImplicitGrant -from .resource_owner_password_credentials import ResourceOwnerPasswordCredentialsGrant +from .base import AuthorizationEndpointMixin +from .base import BaseGrant +from .base import TokenEndpointMixin from .client_credentials import ClientCredentialsGrant +from .implicit import ImplicitGrant from .refresh_token import RefreshTokenGrant +from .resource_owner_password_credentials import ResourceOwnerPasswordCredentialsGrant __all__ = [ - 'BaseGrant', 'AuthorizationEndpointMixin', 'TokenEndpointMixin', - 'AuthorizationCodeGrant', 'ImplicitGrant', - 'ResourceOwnerPasswordCredentialsGrant', - 'ClientCredentialsGrant', 'RefreshTokenGrant', + "BaseGrant", + "AuthorizationEndpointMixin", + "TokenEndpointMixin", + "AuthorizationCodeGrant", + "ImplicitGrant", + "ResourceOwnerPasswordCredentialsGrant", + "ClientCredentialsGrant", + "RefreshTokenGrant", ] diff --git a/authlib/oauth2/rfc6749/grants/authorization_code.py b/authlib/oauth2/rfc6749/grants/authorization_code.py index 10599cb44..ebde2763a 100644 --- a/authlib/oauth2/rfc6749/grants/authorization_code.py +++ b/authlib/oauth2/rfc6749/grants/authorization_code.py @@ -1,15 +1,19 @@ import logging -from authlib.deprecate import deprecate -from authlib.common.urls import add_params_to_uri + from authlib.common.security import generate_token -from .base import BaseGrant, AuthorizationEndpointMixin, TokenEndpointMixin -from ..errors import ( - OAuth2Error, - UnauthorizedClientError, - InvalidClientError, - InvalidRequestError, - AccessDeniedError, -) +from authlib.common.urls import add_params_to_uri + +from ..errors import AccessDeniedError +from ..errors import InvalidClientError +from ..errors import InvalidGrantError +from ..errors import InvalidRequestError +from ..errors import InvalidScopeError +from ..errors import OAuth2Error +from ..errors import UnauthorizedClientError +from ..hooks import hooked +from .base import AuthorizationEndpointMixin +from .base import BaseGrant +from .base import TokenEndpointMixin log = logging.getLogger(__name__) @@ -48,14 +52,15 @@ class AuthorizationCodeGrant(BaseGrant, AuthorizationEndpointMixin, TokenEndpoin | |<---(E)----- Access Token -------------------' +---------+ (w/ Optional Refresh Token) """ + #: Allowed client auth methods for token endpoint - TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post'] + TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"] #: Generated "code" length AUTHORIZATION_CODE_LENGTH = 48 - RESPONSE_TYPES = {'code'} - GRANT_TYPE = 'authorization_code' + RESPONSE_TYPES = {"code"} + GRANT_TYPE = "authorization_code" def validate_authorization_request(self): """The client constructs the request URI by adding the following @@ -107,7 +112,7 @@ def validate_authorization_request(self): """ return validate_code_authorization_request(self) - def create_authorization_response(self, redirect_uri, grant_user): + def create_authorization_response(self, redirect_uri: str, grant_user): """If the resource owner grants the access request, the authorization server issues an authorization code and delivers it to the client by adding the following parameters to the query component of the @@ -147,24 +152,21 @@ def create_authorization_response(self, redirect_uri, grant_user): :returns: (status_code, body, headers) """ if not grant_user: - raise AccessDeniedError(state=self.request.state, redirect_uri=redirect_uri) + raise AccessDeniedError(redirect_uri=redirect_uri) self.request.user = grant_user - if hasattr(self, 'create_authorization_code'): # pragma: no cover - deprecate('Use "generate_authorization_code" instead', '1.0') - client = self.request.client - code = self.create_authorization_code(client, grant_user, self.request) - else: - code = self.generate_authorization_code() - self.save_authorization_code(code, self.request) - - params = [('code', code)] - if self.request.state: - params.append(('state', self.request.state)) + + code = self.generate_authorization_code() + self.save_authorization_code(code, self.request) + + params = [("code", code)] + if self.request.payload.state: + params.append(("state", self.request.payload.state)) uri = add_params_to_uri(redirect_uri, params) - headers = [('Location', uri)] - return 302, '', headers + headers = [("Location", uri)] + return 302, "", headers + @hooked def validate_token_request(self): """The client makes a request to the token endpoint by sending the following parameters using the "application/x-www-form-urlencoded" @@ -211,33 +213,35 @@ def validate_token_request(self): # authenticate the client if client authentication is included client = self.authenticate_token_endpoint_client() - log.debug('Validate token request of %r', client) + log.debug("Validate token request of %r", client) if not client.check_grant_type(self.GRANT_TYPE): - raise UnauthorizedClientError() + raise UnauthorizedClientError( + f"The client is not authorized to use 'grant_type={self.GRANT_TYPE}'" + ) - code = self.request.form.get('code') + code = self.request.form.get("code") if code is None: - raise InvalidRequestError('Missing "code" in request.') + raise InvalidRequestError("Missing 'code' in request.") # ensure that the authorization code was issued to the authenticated # confidential client, or if the client is public, ensure that the # code was issued to "client_id" in the request authorization_code = self.query_authorization_code(code, client) if not authorization_code: - raise InvalidRequestError('Invalid "code" in request.') + raise InvalidGrantError("Invalid 'code' in request.") # validate redirect_uri parameter - log.debug('Validate token redirect_uri of %r', client) - redirect_uri = self.request.redirect_uri + log.debug("Validate token redirect_uri of %r", client) + redirect_uri = self.request.payload.redirect_uri original_redirect_uri = authorization_code.get_redirect_uri() if original_redirect_uri and redirect_uri != original_redirect_uri: - raise InvalidRequestError('Invalid "redirect_uri" in request.') + raise InvalidGrantError("Invalid 'redirect_uri' in request.") # save for create_token_response self.request.client = client - self.request.credential = authorization_code - self.execute_hook('after_validate_token_request') + self.request.authorization_code = authorization_code + @hooked def create_token_response(self): """If the access token request is valid and authorized, the authorization server issues an access token and optional refresh @@ -267,28 +271,27 @@ def create_token_response(self): .. _`Section 4.1.4`: https://tools.ietf.org/html/rfc6749#section-4.1.4 """ client = self.request.client - authorization_code = self.request.credential + authorization_code = self.request.authorization_code user = self.authenticate_user(authorization_code) if not user: - raise InvalidRequestError('There is no "user" for this code.') + raise InvalidGrantError("There is no 'user' for this code.") + self.request.user = user scope = authorization_code.get_scope() token = self.generate_token( user=user, scope=scope, - include_refresh_token=client.check_grant_type('refresh_token'), + include_refresh_token=client.check_grant_type("refresh_token"), ) - log.debug('Issue token %r to %r', token, client) + log.debug("Issue token %r to %r", token, client) - self.request.user = user self.save_token(token) - self.execute_hook('process_token', token=token) self.delete_authorization_code(authorization_code) return 200, token, self.TOKEN_RESPONSE_HEADER def generate_authorization_code(self): - """"The method to generate "code" value for authorization code data. + """ "The method to generate "code" value for authorization code data. Developers may rewrite this method, or customize the code length with:: class MyAuthorizationCodeGrant(AuthorizationCodeGrant): @@ -305,11 +308,16 @@ def save_authorization_code(self, code, request): item = AuthorizationCode( code=code, client_id=client.client_id, - redirect_uri=request.redirect_uri, + redirect_uri=request.payload.redirect_uri, scope=request.scope, user_id=request.user.id, ) item.save() + + .. note:: Use ``request.scope`` instead of ``request.payload.scope`` to get + the resolved scope. Per RFC 6749 Section 3.3, if the client omits the + scope parameter, the server uses a default value from + ``client.get_allowed_scope()``. """ raise NotImplementedError() @@ -324,9 +332,6 @@ def query_authorization_code(self, code, client): :param client: client related to this code. :return: authorization_code object """ - if hasattr(self, 'parse_authorization_code'): - deprecate('Use "query_authorization_code" instead', '1.0') - return self.parse_authorization_code(code, client) raise NotImplementedError() def delete_authorization_code(self, authorization_code): @@ -345,7 +350,7 @@ def authenticate_user(self, authorization_code): MUST implement this method in subclass, e.g.:: def authenticate_user(self, authorization_code): - return User.query.get(authorization_code.user_id) + return User.get(authorization_code.user_id) :param authorization_code: AuthorizationCode object :return: user @@ -354,30 +359,41 @@ def authenticate_user(self, authorization_code): def validate_code_authorization_request(grant): - client_id = grant.request.client_id - log.debug('Validate authorization request of %r', client_id) + request = grant.request + client_id = request.payload.client_id + log.debug("Validate authorization request of %r", client_id) if client_id is None: - raise InvalidClientError(state=grant.request.state) + raise InvalidClientError( + description="Missing 'client_id' parameter.", + ) client = grant.server.query_client(client_id) if not client: - raise InvalidClientError(state=grant.request.state) + raise InvalidClientError( + description="The client does not exist on this server.", + ) - redirect_uri = grant.validate_authorization_redirect_uri(grant.request, client) - response_type = grant.request.response_type + redirect_uri = grant.validate_authorization_redirect_uri(request, client) + response_type = request.payload.response_type if not client.check_response_type(response_type): raise UnauthorizedClientError( - 'The client is not authorized to use ' - '"response_type={}"'.format(response_type), - state=grant.request.state, + f"The client is not authorized to use 'response_type={response_type}'", redirect_uri=redirect_uri, ) - try: - grant.request.client = client + grant.request.client = client + + @hooked + def validate_authorization_request_payload(grant, redirect_uri): grant.validate_requested_scope() - grant.execute_hook('after_validate_authorization_request') + scope = client.get_allowed_scope(request.payload.scope) + if scope is None: + raise InvalidScopeError() + request.scope = scope + + try: + validate_authorization_request_payload(grant, redirect_uri) except OAuth2Error as error: error.redirect_uri = redirect_uri raise error diff --git a/authlib/oauth2/rfc6749/grants/base.py b/authlib/oauth2/rfc6749/grants/base.py index 9fe03c904..bd1de087d 100644 --- a/authlib/oauth2/rfc6749/grants/base.py +++ b/authlib/oauth2/rfc6749/grants/base.py @@ -1,10 +1,14 @@ from authlib.consts import default_json_headers + from ..errors import InvalidRequestError +from ..hooks import Hookable +from ..hooks import hooked +from ..requests import OAuth2Request -class BaseGrant(object): +class BaseGrant(Hookable): #: Allowed client auth methods for token endpoint - TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic'] + TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic"] #: Designed for which "grant_type" GRANT_TYPE = None @@ -15,32 +19,29 @@ class BaseGrant(object): # https://tools.ietf.org/html/rfc4627 TOKEN_RESPONSE_HEADER = default_json_headers - def __init__(self, request, server): + def __init__(self, request: OAuth2Request, server): + super().__init__() + self.prompt = None + self.redirect_uri = None self.request = request self.server = server - self._hooks = { - 'after_validate_authorization_request': set(), - 'after_validate_consent_request': set(), - 'after_validate_token_request': set(), - 'process_token': set(), - } @property def client(self): return self.request.client - def generate_token(self, user=None, scope=None, grant_type=None, - expires_in=None, include_refresh_token=True): - + def generate_token( + self, + user=None, + scope=None, + grant_type=None, + expires_in=None, + include_refresh_token=True, + ): if grant_type is None: grant_type = self.GRANT_TYPE - - client = self.request.client - if scope is not None: - scope = client.get_allowed_scope(scope) - return self.server.generate_token( - client=client, + client=self.request.client, grant_type=grant_type, user=user, scope=scope, @@ -69,11 +70,9 @@ def authenticate_token_endpoint_client(self): :return: client """ client = self.server.authenticate_client( - self.request, - self.TOKEN_ENDPOINT_AUTH_METHODS) - self.server.send_signal( - 'after_authenticate_client', - client=client, grant=self) + self.request, self.TOKEN_ENDPOINT_AUTH_METHODS + ) + self.server.send_signal("after_authenticate_client", client=client, grant=self) return client def save_token(self, token): @@ -82,31 +81,23 @@ def save_token(self, token): def validate_requested_scope(self): """Validate if requested scope is supported by Authorization Server.""" - scope = self.request.scope - state = self.request.state - return self.server.validate_requested_scope(scope, state) - - def register_hook(self, hook_type, hook): - if hook_type not in self._hooks: - raise ValueError('Hook type %s is not in %s.', - hook_type, self._hooks) - self._hooks[hook_type].add(hook) - - def execute_hook(self, hook_type, *args, **kwargs): - for hook in self._hooks[hook_type]: - hook(self, *args, **kwargs) + scope = self.request.payload.scope + return self.server.validate_requested_scope(scope) -class TokenEndpointMixin(object): +class TokenEndpointMixin: #: Allowed HTTP methods of this token endpoint - TOKEN_ENDPOINT_HTTP_METHODS = ['POST'] + TOKEN_ENDPOINT_HTTP_METHODS = ["POST"] #: Designed for which "grant_type" GRANT_TYPE = None @classmethod - def check_token_endpoint(cls, request): - return request.grant_type == cls.GRANT_TYPE + def check_token_endpoint(cls, request: OAuth2Request): + return ( + request.payload.grant_type == cls.GRANT_TYPE + and request.method in cls.TOKEN_ENDPOINT_HTTP_METHODS + ) def validate_token_request(self): raise NotImplementedError() @@ -115,37 +106,53 @@ def create_token_response(self): raise NotImplementedError() -class AuthorizationEndpointMixin(object): +class AuthorizationEndpointMixin: RESPONSE_TYPES = set() ERROR_RESPONSE_FRAGMENT = False @classmethod - def check_authorization_endpoint(cls, request): - return request.response_type in cls.RESPONSE_TYPES + def check_authorization_endpoint(cls, request: OAuth2Request): + return request.payload.response_type in cls.RESPONSE_TYPES @staticmethod - def validate_authorization_redirect_uri(request, client): - if request.redirect_uri: - if not client.check_redirect_uri(request.redirect_uri): + def validate_authorization_redirect_uri(request: OAuth2Request, client): + if request.payload.redirect_uri: + if not client.check_redirect_uri(request.payload.redirect_uri): raise InvalidRequestError( - 'Redirect URI {!r} is not supported by client.'.format(request.redirect_uri), - state=request.state, + f"Redirect URI {request.payload.redirect_uri} is not supported by client.", ) - return request.redirect_uri + return request.payload.redirect_uri else: redirect_uri = client.get_default_redirect_uri() if not redirect_uri: raise InvalidRequestError( - 'Missing "redirect_uri" in request.' + "Missing 'redirect_uri' in request.", state=request.payload.state ) return redirect_uri + @staticmethod + def validate_no_multiple_request_parameter(request: OAuth2Request): + """For the Authorization Endpoint, request and response parameters MUST NOT be included + more than once. Per `Section 3.1`_. + + .. _`Section 3.1`: https://tools.ietf.org/html/rfc6749#section-3.1 + """ + datalist = request.payload.datalist + parameters = ["response_type", "client_id", "redirect_uri", "scope", "state"] + for param in parameters: + if len(datalist.get(param, [])) > 1: + raise InvalidRequestError( + f"Multiple '{param}' in request.", state=request.payload.state + ) + + @hooked def validate_consent_request(self): redirect_uri = self.validate_authorization_request() - self.execute_hook('after_validate_consent_request', redirect_uri) + self.redirect_uri = redirect_uri + return redirect_uri def validate_authorization_request(self): raise NotImplementedError() - def create_authorization_response(self, redirect_uri, grant_user): + def create_authorization_response(self, redirect_uri: str, grant_user): raise NotImplementedError() diff --git a/authlib/oauth2/rfc6749/grants/client_credentials.py b/authlib/oauth2/rfc6749/grants/client_credentials.py index 784a37028..3b0ff7d2f 100644 --- a/authlib/oauth2/rfc6749/grants/client_credentials.py +++ b/authlib/oauth2/rfc6749/grants/client_credentials.py @@ -1,6 +1,9 @@ import logging -from .base import BaseGrant, TokenEndpointMixin + from ..errors import UnauthorizedClientError +from ..hooks import hooked +from .base import BaseGrant +from .base import TokenEndpointMixin log = logging.getLogger(__name__) @@ -25,7 +28,8 @@ class ClientCredentialsGrant(BaseGrant, TokenEndpointMixin): https://tools.ietf.org/html/rfc6749#section-4.4 """ - GRANT_TYPE = 'client_credentials' + + GRANT_TYPE = "client_credentials" def validate_token_request(self): """The client makes a request to the token endpoint by adding the @@ -58,18 +62,20 @@ def validate_token_request(self): The authorization server MUST authenticate the client. """ - # ignore validate for grant_type, since it is validated by # check_token_endpoint client = self.authenticate_token_endpoint_client() - log.debug('Validate token request of %r', client) + log.debug("Validate token request of %r", client) if not client.check_grant_type(self.GRANT_TYPE): - raise UnauthorizedClientError() + raise UnauthorizedClientError( + f"The client is not authorized to use 'grant_type={self.GRANT_TYPE}'" + ) self.request.client = client self.validate_requested_scope() + @hooked def create_token_response(self): """If the access token request is valid and authorized, the authorization server issues an access token as described in @@ -95,9 +101,9 @@ def create_token_response(self): :returns: (status_code, body, headers) """ - client = self.request.client - token = self.generate_token(scope=self.request.scope, include_refresh_token=False) - log.debug('Issue token %r to %r', token, client) + token = self.generate_token( + scope=self.request.payload.scope, include_refresh_token=False + ) + log.debug("Issue token %r to %r", token, self.client) self.save_token(token) - self.execute_hook('process_token', self, token=token) return 200, token, self.TOKEN_RESPONSE_HEADER diff --git a/authlib/oauth2/rfc6749/grants/implicit.py b/authlib/oauth2/rfc6749/grants/implicit.py index 75b12be43..c58a0a538 100644 --- a/authlib/oauth2/rfc6749/grants/implicit.py +++ b/authlib/oauth2/rfc6749/grants/implicit.py @@ -1,11 +1,14 @@ import logging + from authlib.common.urls import add_params_to_uri -from .base import BaseGrant, AuthorizationEndpointMixin -from ..errors import ( - OAuth2Error, - UnauthorizedClientError, - AccessDeniedError, -) + +from ..errors import AccessDeniedError +from ..errors import InvalidScopeError +from ..errors import OAuth2Error +from ..errors import UnauthorizedClientError +from ..hooks import hooked +from .base import AuthorizationEndpointMixin +from .base import BaseGrant log = logging.getLogger(__name__) @@ -66,15 +69,17 @@ class ImplicitGrant(BaseGrant, AuthorizationEndpointMixin): | | +---------+ """ + #: authorization_code grant type has authorization endpoint AUTHORIZATION_ENDPOINT = True #: Allowed client auth methods for token endpoint - TOKEN_ENDPOINT_AUTH_METHODS = ['none'] + TOKEN_ENDPOINT_AUTH_METHODS = ["none"] - RESPONSE_TYPES = {'token'} - GRANT_TYPE = 'implicit' + RESPONSE_TYPES = {"token"} + GRANT_TYPE = "implicit" ERROR_RESPONSE_FRAGMENT = True + @hooked def validate_authorization_request(self): """The client constructs the request URI by adding the following parameters to the query component of the authorization endpoint URI @@ -121,17 +126,14 @@ def validate_authorization_request(self): # The implicit grant type is optimized for public clients client = self.authenticate_token_endpoint_client() - log.debug('Validate authorization request of %r', client) + log.debug("Validate authorization request of %r", client) - redirect_uri = self.validate_authorization_redirect_uri( - self.request, client) + redirect_uri = self.validate_authorization_redirect_uri(self.request, client) - response_type = self.request.response_type + response_type = self.request.payload.response_type if not client.check_response_type(response_type): raise UnauthorizedClientError( - 'The client is not authorized to use ' - '"response_type={}"'.format(response_type), - state=self.request.state, + f"The client is not authorized to use 'response_type={response_type}'", redirect_uri=redirect_uri, redirect_fragment=True, ) @@ -139,13 +141,17 @@ def validate_authorization_request(self): try: self.request.client = client self.validate_requested_scope() - self.execute_hook('after_validate_authorization_request') + scope = client.get_allowed_scope(self.request.payload.scope) + if scope is None: + raise InvalidScopeError() + self.request.scope = scope except OAuth2Error as error: error.redirect_uri = redirect_uri error.redirect_fragment = True raise error return redirect_uri + @hooked def create_authorization_response(self, redirect_uri, grant_user): """If the resource owner grants the access request, the authorization server issues an access token and delivers it to the client by adding @@ -202,7 +208,7 @@ def create_authorization_response(self, redirect_uri, grant_user): resource owner, otherwise pass None. :returns: (status_code, body, headers) """ - state = self.request.state + state = self.request.payload.state if grant_user: self.request.user = grant_user token = self.generate_token( @@ -210,20 +216,15 @@ def create_authorization_response(self, redirect_uri, grant_user): scope=self.request.scope, include_refresh_token=False, ) - log.debug('Grant token %r to %r', token, self.request.client) + log.debug("Grant token %r to %r", token, self.request.client) self.save_token(token) - self.execute_hook('process_token', token=token) params = [(k, token[k]) for k in token] if state: - params.append(('state', state)) + params.append(("state", state)) uri = add_params_to_uri(redirect_uri, params, fragment=True) - headers = [('Location', uri)] - return 302, '', headers + headers = [("Location", uri)] + return 302, "", headers else: - raise AccessDeniedError( - state=state, - redirect_uri=redirect_uri, - redirect_fragment=True - ) + raise AccessDeniedError(redirect_uri=redirect_uri, redirect_fragment=True) diff --git a/authlib/oauth2/rfc6749/grants/refresh_token.py b/authlib/oauth2/rfc6749/grants/refresh_token.py index d29f4f950..d1e502dba 100644 --- a/authlib/oauth2/rfc6749/grants/refresh_token.py +++ b/authlib/oauth2/rfc6749/grants/refresh_token.py @@ -1,22 +1,23 @@ -""" - authlib.oauth2.rfc6749.grants.refresh_token - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth2.rfc6749.grants.refresh_token. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - A special grant endpoint for refresh_token grant_type. Refreshing an - Access Token per `Section 6`_. +A special grant endpoint for refresh_token grant_type. Refreshing an +Access Token per `Section 6`_. - .. _`Section 6`: https://tools.ietf.org/html/rfc6749#section-6 +.. _`Section 6`: https://tools.ietf.org/html/rfc6749#section-6 """ import logging -from .base import BaseGrant, TokenEndpointMixin + +from ..errors import InvalidGrantError +from ..errors import InvalidRequestError +from ..errors import InvalidScopeError +from ..errors import UnauthorizedClientError +from ..hooks import hooked from ..util import scope_to_list -from ..errors import ( - InvalidRequestError, - InvalidScopeError, - InvalidGrantError, - UnauthorizedClientError, -) +from .base import BaseGrant +from .base import TokenEndpointMixin + log = logging.getLogger(__name__) @@ -26,7 +27,8 @@ class RefreshTokenGrant(BaseGrant, TokenEndpointMixin): .. _`Section 6`: https://tools.ietf.org/html/rfc6749#section-6 """ - GRANT_TYPE = 'refresh_token' + + GRANT_TYPE = "refresh_token" #: The authorization server MAY issue a new refresh token INCLUDE_NEW_REFRESH_TOKEN = False @@ -36,27 +38,27 @@ def _validate_request_client(self): # client that was issued client credentials (or with other # authentication requirements) client = self.authenticate_token_endpoint_client() - log.debug('Validate token request of %r', client) + log.debug("Validate token request of %r", client) if not client.check_grant_type(self.GRANT_TYPE): - raise UnauthorizedClientError() + raise UnauthorizedClientError( + f"The client is not authorized to use 'grant_type={self.GRANT_TYPE}'" + ) return client def _validate_request_token(self, client): - refresh_token = self.request.form.get('refresh_token') + refresh_token = self.request.form.get("refresh_token") if refresh_token is None: - raise InvalidRequestError( - 'Missing "refresh_token" in request.', - ) + raise InvalidRequestError("Missing 'refresh_token' in request.") token = self.authenticate_refresh_token(refresh_token) - if not token or token.get_client_id() != client.get_client_id(): + if not token or not token.check_client(client): raise InvalidGrantError() return token def _validate_token_scope(self, token): - scope = self.request.scope + scope = self.request.payload.scope if not scope: return @@ -104,40 +106,38 @@ def validate_token_request(self): """ client = self._validate_request_client() self.request.client = client - token = self._validate_request_token(client) - self._validate_token_scope(token) - self.request.credential = token + refresh_token = self._validate_request_token(client) + self._validate_token_scope(refresh_token) + self.request.refresh_token = refresh_token + @hooked def create_token_response(self): """If valid and authorized, the authorization server issues an access token as described in Section 5.1. If the request failed verification or is invalid, the authorization server returns an error response as described in Section 5.2. """ - credential = self.request.credential - user = self.authenticate_user(credential) + refresh_token = self.request.refresh_token + user = self.authenticate_user(refresh_token) if not user: - raise InvalidRequestError('There is no "user" for this token.') + raise InvalidRequestError("There is no 'user' for this token.") client = self.request.client - token = self.issue_token(user, credential) - log.debug('Issue token %r to %r', token, client) + token = self.issue_token(user, refresh_token) + log.debug("Issue token %r to %r", token, client) self.request.user = user self.save_token(token) - self.execute_hook('process_token', token=token) - self.revoke_old_credential(credential) + self.revoke_old_credential(refresh_token) return 200, token, self.TOKEN_RESPONSE_HEADER - def issue_token(self, user, credential): - expires_in = credential.get_expires_in() - scope = self.request.scope + def issue_token(self, user, refresh_token): + scope = self.request.payload.scope if not scope: - scope = credential.get_scope() + scope = refresh_token.get_scope() token = self.generate_token( user=user, - expires_in=expires_in, scope=scope, include_refresh_token=self.INCLUDE_NEW_REFRESH_TOKEN, ) @@ -148,36 +148,36 @@ def authenticate_refresh_token(self, refresh_token): implement this method in subclass:: def authenticate_refresh_token(self, refresh_token): - item = Token.get(refresh_token=refresh_token) - if item and item.is_refresh_token_active(): - return item + token = Token.get(refresh_token=refresh_token) + if token and not token.refresh_token_revoked: + return token :param refresh_token: The refresh token issued to the client :return: token """ raise NotImplementedError() - def authenticate_user(self, credential): + def authenticate_user(self, refresh_token): """Authenticate the user related to this credential. Developers MUST implement this method in subclass:: def authenticate_user(self, credential): - return User.query.get(credential.user_id) + return User.get(credential.user_id) - :param credential: Token object + :param refresh_token: Token object :return: user """ raise NotImplementedError() - def revoke_old_credential(self, credential): + def revoke_old_credential(self, refresh_token): """The authorization server MAY revoke the old refresh token after issuing a new refresh token to the client. Developers MUST implement this method in subclass:: - def revoke_old_credential(self, credential): + def revoke_old_credential(self, refresh_token): credential.revoked = True credential.save() - :param credential: Token object + :param refresh_token: Token object """ raise NotImplementedError() diff --git a/authlib/oauth2/rfc6749/grants/resource_owner_password_credentials.py b/authlib/oauth2/rfc6749/grants/resource_owner_password_credentials.py index df31c867c..ce1c487c4 100644 --- a/authlib/oauth2/rfc6749/grants/resource_owner_password_credentials.py +++ b/authlib/oauth2/rfc6749/grants/resource_owner_password_credentials.py @@ -1,9 +1,10 @@ import logging -from .base import BaseGrant, TokenEndpointMixin -from ..errors import ( - UnauthorizedClientError, - InvalidRequestError, -) + +from ..errors import InvalidRequestError +from ..errors import UnauthorizedClientError +from ..hooks import hooked +from .base import BaseGrant +from .base import TokenEndpointMixin log = logging.getLogger(__name__) @@ -11,7 +12,7 @@ class ResourceOwnerPasswordCredentialsGrant(BaseGrant, TokenEndpointMixin): """The resource owner password credentials grant type is suitable in cases where the resource owner has a trust relationship with the - client, such as the device operating system or a highly privileged + client, such as the device operating system or a highly privileged. application. The authorization server should take special care when enabling this grant type and only allow it when other flows are not @@ -42,7 +43,8 @@ class ResourceOwnerPasswordCredentialsGrant(BaseGrant, TokenEndpointMixin): | | (w/ Optional Refresh Token) | | +---------+ +---------------+ """ - GRANT_TYPE = 'password' + + GRANT_TYPE = "password" def validate_token_request(self): """The client makes a request to the token endpoint by adding the @@ -84,30 +86,30 @@ def validate_token_request(self): # ignore validate for grant_type, since it is validated by # check_token_endpoint client = self.authenticate_token_endpoint_client() - log.debug('Validate token request of %r', client) + log.debug("Validate token request of %r", client) if not client.check_grant_type(self.GRANT_TYPE): - raise UnauthorizedClientError() + raise UnauthorizedClientError( + f"The client is not authorized to use 'grant_type={self.GRANT_TYPE}'" + ) params = self.request.form - if 'username' not in params: - raise InvalidRequestError('Missing "username" in request.') - if 'password' not in params: - raise InvalidRequestError('Missing "password" in request.') - - log.debug('Authenticate user of %r', params['username']) - user = self.authenticate_user( - params['username'], - params['password'] - ) + if "username" not in params: + raise InvalidRequestError("Missing 'username' in request.") + if "password" not in params: + raise InvalidRequestError("Missing 'password' in request.") + + log.debug("Authenticate user of %r", params["username"]) + user = self.authenticate_user(params["username"], params["password"]) if not user: raise InvalidRequestError( - 'Invalid "username" or "password" in request.', + "Invalid 'username' or 'password' in request.", ) self.request.client = client self.request.user = user self.validate_requested_scope() + @hooked def create_token_response(self): """If the access token request is valid and authorized, the authorization server issues an access token and optional refresh @@ -135,20 +137,19 @@ def create_token_response(self): :returns: (status_code, body, headers) """ user = self.request.user - scope = self.request.scope + scope = self.request.payload.scope token = self.generate_token(user=user, scope=scope) - log.debug('Issue token %r to %r', token, self.request.client) + log.debug("Issue token %r to %r", token, self.client) self.save_token(token) - self.execute_hook('process_token', token=token) return 200, token, self.TOKEN_RESPONSE_HEADER def authenticate_user(self, username, password): - """validate the resource owner password credentials using its + """Validate the resource owner password credentials using its existing password validation algorithm:: def authenticate_user(self, username, password): user = get_user_by_username(username) if user.check_password(password): - return user + return user """ raise NotImplementedError() diff --git a/authlib/oauth2/rfc6749/hooks.py b/authlib/oauth2/rfc6749/hooks.py new file mode 100644 index 000000000..376f0e187 --- /dev/null +++ b/authlib/oauth2/rfc6749/hooks.py @@ -0,0 +1,37 @@ +from collections import defaultdict + + +class Hookable: + _hooks = None + + def __init__(self): + self._hooks = defaultdict(set) + + def register_hook(self, hook_type, hook): + self._hooks[hook_type].add(hook) + + def execute_hook(self, hook_type, *args, **kwargs): + for hook in self._hooks[hook_type]: + hook(self, *args, **kwargs) + + +def hooked(func=None, before=None, after=None): + """Execute hooks before and after the decorated method.""" + + def decorator(func): + before_name = before or f"before_{func.__name__}" + after_name = after or f"after_{func.__name__}" + + def wrapper(self, *args, **kwargs): + self.execute_hook(before_name, *args, **kwargs) + result = func(self, *args, **kwargs) + self.execute_hook(after_name, result) + return result + + return wrapper + + # The decorator has been called without parenthesis + if callable(func): + return decorator(func) + + return decorator diff --git a/authlib/oauth2/rfc6749/models.py b/authlib/oauth2/rfc6749/models.py index 0f86a4aa2..f3eaef662 100644 --- a/authlib/oauth2/rfc6749/models.py +++ b/authlib/oauth2/rfc6749/models.py @@ -1,12 +1,11 @@ -""" - authlib.oauth2.rfc6749.models - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth2.rfc6749.models. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - This module defines how to construct Client, AuthorizationCode and Token. +This module defines how to construct Client, AuthorizationCode and Token. """ -class ClientMixin(object): +class ClientMixin: """Implementation of OAuth 2 Client described in `Section 2`_ with some methods to help validation. A client has at least these information: @@ -46,7 +45,7 @@ def get_allowed_scope(self, scope): def get_allowed_scope(self, scope): if not scope: - return '' + return "" allowed = set(scope_to_list(self.scope)) return list_to_scope([s for s in scope.split() if s in allowed]) @@ -68,32 +67,33 @@ def check_redirect_uri(self, redirect_uri): """ raise NotImplementedError() - def has_client_secret(self): - """A method returns that if the client has ``client_secret`` value. - If the value is in ``client_secret`` column:: - - def has_client_secret(self): - return bool(self.client_secret) - - :return: bool - """ - raise NotImplementedError() - def check_client_secret(self, client_secret): """Check client_secret matching with the client. For instance, in the client table, the column is called ``client_secret``:: + import secrets + + def check_client_secret(self, client_secret): - return self.client_secret == client_secret + return secrets.compare_digest(self.client_secret, client_secret) :param client_secret: A string of client secret :return: bool """ raise NotImplementedError() - def check_token_endpoint_auth_method(self, method): - """Check client ``token_endpoint_auth_method`` defined via `RFC7591`_. - Values defined by this specification are: + def check_endpoint_auth_method(self, method, endpoint): + """Check if client support the given method for the given endpoint. + There is a ``token_endpoint_auth_method`` defined via `RFC7591`_. + Developers MAY re-implement this method with:: + + def check_endpoint_auth_method(self, method, endpoint): + if endpoint == "token": + # if client table has ``token_endpoint_auth_method`` + return self.token_endpoint_auth_method == method + return True + + Method values defined by this specification are: * "none": The client is a public client as defined in OAuth 2.0, and does not have a client secret. @@ -141,7 +141,7 @@ def check_grant_type(self, grant_type): raise NotImplementedError() -class AuthorizationCodeMixin(object): +class AuthorizationCodeMixin: def get_redirect_uri(self): """A method to get authorization code's ``redirect_uri``. For instance, the database table for authorization code has a @@ -166,15 +166,15 @@ def get_scope(self): raise NotImplementedError() -class TokenMixin(object): - def get_client_id(self): - """A method to return client_id of the token. For instance, the value - in database is saved in a column called ``client_id``:: +class TokenMixin: + def check_client(self, client): + """A method to check if this token is issued to the given client. + For instance, ``client_id`` is saved on token table:: - def get_client_id(self): - return self.client_id + def check_client(self, client): + return self.client_id == client.client_id - :return: string + :return: bool """ raise NotImplementedError() @@ -200,13 +200,44 @@ def get_expires_in(self): """ raise NotImplementedError() - def get_expires_at(self): - """A method to get the value when this token will be expired. e.g. - it would be:: + def is_expired(self): + """A method to define if this token is expired. For instance, + there is a column ``expired_at`` in the table:: - def get_expires_at(self): - return self.created_at + self.expires_in + def is_expired(self): + return self.expired_at < now - :return: timestamp int + :return: boolean + """ + raise NotImplementedError() + + def is_revoked(self): + """A method to define if this token is revoked. For instance, + there is a boolean column ``revoked`` in the table:: + + def is_revoked(self): + return self.revoked + + :return: boolean + """ + raise NotImplementedError() + + def get_user(self): + """A method to get the user object associated with this token: + + .. code-block:: + + def get_user(self): + return User.get(self.user_id) + """ + raise NotImplementedError() + + def get_client(self) -> ClientMixin: + """A method to get the client object associated with this token: + + .. code-block:: + + def get_client(self): + return Client.get(self.client_id) """ raise NotImplementedError() diff --git a/authlib/oauth2/rfc6749/parameters.py b/authlib/oauth2/rfc6749/parameters.py index 20461fdbc..a575fe726 100644 --- a/authlib/oauth2/rfc6749/parameters.py +++ b/authlib/oauth2/rfc6749/parameters.py @@ -1,20 +1,18 @@ -from authlib.common.urls import ( - urlparse, - add_params_to_uri, - add_params_to_qs, -) from authlib.common.encoding import to_unicode -from .errors import ( - MissingCodeException, - MissingTokenException, - MissingTokenTypeException, - MismatchingStateException, -) +from authlib.common.urls import add_params_to_qs +from authlib.common.urls import add_params_to_uri +from authlib.common.urls import urlparse + +from .errors import MismatchingStateException +from .errors import MissingCodeException +from .errors import MissingTokenException +from .errors import MissingTokenTypeException from .util import list_to_scope -def prepare_grant_uri(uri, client_id, response_type, redirect_uri=None, - scope=None, state=None, **kwargs): +def prepare_grant_uri( + uri, client_id, response_type, redirect_uri=None, scope=None, state=None, **kwargs +): """Prepare the authorization grant request URI. The client constructs the request URI by adding the following @@ -47,26 +45,28 @@ def prepare_grant_uri(uri, client_id, response_type, redirect_uri=None, .. _`Section 3.3`: https://tools.ietf.org/html/rfc6749#section-3.3 .. _`section 10.12`: https://tools.ietf.org/html/rfc6749#section-10.12 """ - params = [ - ('response_type', response_type), - ('client_id', client_id) - ] + params = [("response_type", response_type), ("client_id", client_id)] if redirect_uri: - params.append(('redirect_uri', redirect_uri)) + params.append(("redirect_uri", redirect_uri)) if scope: - params.append(('scope', list_to_scope(scope))) + params.append(("scope", list_to_scope(scope))) if state: - params.append(('state', state)) + params.append(("state", state)) - for k in kwargs: - if kwargs[k]: - params.append((to_unicode(k), kwargs[k])) + for k, value in kwargs.items(): + if value is not None: + if isinstance(value, (list, tuple)): + for v in value: + if v is not None: + params.append((to_unicode(k), v)) + else: + params.append((to_unicode(k), value)) return add_params_to_uri(uri, params) -def prepare_token_request(grant_type, body='', redirect_uri=None, **kwargs): +def prepare_token_request(grant_type, body="", redirect_uri=None, **kwargs): """Prepare the access token request. Per `Section 4.1.3`_. The client makes a request to the token endpoint by adding the @@ -89,15 +89,15 @@ def prepare_token_request(grant_type, body='', redirect_uri=None, **kwargs): .. _`Section 4.1.1`: https://tools.ietf.org/html/rfc6749#section-4.1.1 .. _`Section 4.1.3`: https://tools.ietf.org/html/rfc6749#section-4.1.3 """ - params = [('grant_type', grant_type)] + params = [("grant_type", grant_type)] if redirect_uri: - params.append(('redirect_uri', redirect_uri)) + params.append(("redirect_uri", redirect_uri)) - if 'scope' in kwargs: - kwargs['scope'] = list_to_scope(kwargs['scope']) + if "scope" in kwargs: + kwargs["scope"] = list_to_scope(kwargs["scope"]) - if grant_type == 'authorization_code' and 'code' not in kwargs: + if grant_type == "authorization_code" and kwargs.get("code") is None: raise MissingCodeException() for k in kwargs: @@ -148,10 +148,11 @@ def parse_authorization_code_response(uri, state=None): query = urlparse.urlparse(uri).query params = dict(urlparse.parse_qsl(query)) - if 'code' not in params: + if "code" not in params: raise MissingCodeException() - if state and params.get('state', None) != state: + params_state = params.get("state") + if state and params_state != state: raise MismatchingStateException() return params @@ -201,13 +202,13 @@ def parse_implicit_response(uri, state=None): fragment = urlparse.urlparse(uri).fragment params = dict(urlparse.parse_qsl(fragment, keep_blank_values=True)) - if 'access_token' not in params: + if "access_token" not in params: raise MissingTokenException() - if 'token_type' not in params: + if "token_type" not in params: raise MissingTokenTypeException() - if state and params.get('state', None) != state: + if state and params.get("state", None) != state: raise MismatchingStateException() return params diff --git a/authlib/oauth2/rfc6749/requests.py b/authlib/oauth2/rfc6749/requests.py new file mode 100644 index 000000000..17994c500 --- /dev/null +++ b/authlib/oauth2/rfc6749/requests.py @@ -0,0 +1,199 @@ +from collections import defaultdict + +from authlib.deprecate import deprecate + +from .errors import InsecureTransportError + + +class OAuth2Payload: + @property + def data(self): + raise NotImplementedError() + + @property + def datalist(self) -> defaultdict[str, list]: + raise NotImplementedError() + + @property + def client_id(self) -> str: + """The authorization server issues the registered client a client + identifier -- a unique string representing the registration + information provided by the client. The value is extracted from + request. + + :return: string + """ + return self.data.get("client_id") + + @property + def response_type(self) -> str: + rt = self.data.get("response_type") + if rt and " " in rt: + # sort multiple response types + return " ".join(sorted(rt.split())) + return rt + + @property + def grant_type(self) -> str: + return self.data.get("grant_type") + + @property + def redirect_uri(self): + return self.data.get("redirect_uri") + + @property + def scope(self) -> str: + return self.data.get("scope") + + @property + def state(self): + return self.data.get("state") + + +class BasicOAuth2Payload(OAuth2Payload): + def __init__(self, payload): + self._data = payload + self._datalist = {key: [value] for key, value in payload.items()} + + @property + def data(self): + return self._data + + @property + def datalist(self) -> defaultdict[str, list]: + return self._datalist + + +class OAuth2Request(OAuth2Payload): + def __init__(self, method: str, uri: str, body=None, headers=None): + InsecureTransportError.check(uri) + #: HTTP method + self.method = method + self.uri = uri + #: HTTP headers + self.headers = headers or {} + + # Store body for backward compatibility but issue deprecation warning if used + if body is not None: + deprecate( + "'body' parameter in OAuth2Request is deprecated. " + "Use the payload system instead.", + version="1.8", + ) + self._body = body + + self.payload = None + + self.client = None + self.auth_method = None + self.user = None + self.authorization_code = None + self.refresh_token = None + self.credential = None + self._scope = None + + @property + def args(self): + raise NotImplementedError() + + @property + def form(self): + if self._body: + return self._body + raise NotImplementedError() + + @property + def data(self): + deprecate( + "'request.data' is deprecated in favor of 'request.payload.data'", + version="1.8", + ) + return self.payload.data + + @property + def datalist(self) -> defaultdict[str, list]: + deprecate( + "'request.datalist' is deprecated in favor of 'request.payload.datalist'", + version="1.8", + ) + return self.payload.datalist + + @property + def client_id(self) -> str: + deprecate( + "'request.client_id' is deprecated in favor of 'request.payload.client_id'", + version="1.8", + ) + return self.payload.client_id + + @property + def response_type(self) -> str: + deprecate( + "'request.response_type' is deprecated in favor of 'request.payload.response_type'", + version="1.8", + ) + return self.payload.response_type + + @property + def grant_type(self) -> str: + deprecate( + "'request.grant_type' is deprecated in favor of 'request.payload.grant_type'", + version="1.8", + ) + return self.payload.grant_type + + @property + def redirect_uri(self): + deprecate( + "'request.redirect_uri' is deprecated in favor of 'request.payload.redirect_uri'", + version="1.8", + ) + return self.payload.redirect_uri + + @property + def scope(self) -> str: + if self._scope is not None: + return self._scope + return self.payload.scope + + @scope.setter + def scope(self, value: str): + self._scope = value + + @property + def state(self): + deprecate( + "'request.state' is deprecated in favor of 'request.payload.state'", + version="1.8", + ) + return self.payload.state + + @property + def body(self): + deprecate( + "'request.body' is deprecated. Use the payload system instead.", + version="1.8", + ) + return self._body + + +class JsonPayload: + @property + def data(self): + raise NotImplementedError() + + +class JsonRequest: + def __init__(self, method, uri, headers=None): + self.method = method + self.uri = uri + self.headers = headers or {} + self.payload = None + + @property + def data(self): + deprecate( + "'request.data' is deprecated in favor of 'request.payload.data'", + version="1.8", + ) + return self.payload.data diff --git a/authlib/oauth2/rfc6749/resource_protector.py b/authlib/oauth2/rfc6749/resource_protector.py index 40567950c..11436205e 100644 --- a/authlib/oauth2/rfc6749/resource_protector.py +++ b/authlib/oauth2/rfc6749/resource_protector.py @@ -1,37 +1,148 @@ -""" - authlib.oauth2.rfc6749.resource_protector - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth2.rfc6749.resource_protector. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - Implementation of Accessing Protected Resources per `Section 7`_. +Implementation of Accessing Protected Resources per `Section 7`_. - .. _`Section 7`: https://tools.ietf.org/html/rfc6749#section-7 +.. _`Section 7`: https://tools.ietf.org/html/rfc6749#section-7 """ -from .errors import MissingAuthorizationError, UnsupportedTokenTypeError +from .errors import MissingAuthorizationError +from .errors import UnsupportedTokenTypeError +from .util import scope_to_list + + +class TokenValidator: + """Base token validator class. Subclass this validator to register + into ResourceProtector instance. + """ + + TOKEN_TYPE = "bearer" + + def __init__(self, realm=None, **extra_attributes): + self.realm = realm + self.extra_attributes = extra_attributes + + @staticmethod + def scope_insufficient(token_scopes, required_scopes): + if not required_scopes: + return False + + token_scopes = scope_to_list(token_scopes) + if not token_scopes: + return True + + token_scopes = set(token_scopes) + for scope in required_scopes: + resource_scopes = set(scope_to_list(scope)) + if token_scopes.issuperset(resource_scopes): + return False + + return True + + def authenticate_token(self, token_string): + """A method to query token from database with the given token string. + Developers MUST re-implement this method. For instance:: + + def authenticate_token(self, token_string): + return get_token_from_database(token_string) + + :param token_string: A string to represent the access_token. + :return: token + """ + raise NotImplementedError() + + def validate_request(self, request): + """A method to validate if the HTTP request is valid or not. Developers MUST + re-implement this method. For instance, your server requires a + "X-Device-Version" in the header:: + + def validate_request(self, request): + if "X-Device-Version" not in request.headers: + raise InvalidRequestError() + Usually, you don't have to detect if the request is valid or not. If you have + to, you MUST re-implement this method. -class ResourceProtector(object): + :param request: instance of HttpRequest + :raise: InvalidRequestError + """ + + def validate_token(self, token, scopes, request): + """A method to validate if the authorized token is valid, if it has the + permission on the given scopes. Developers MUST re-implement this method. + e.g, check if token is expired, revoked:: + + def validate_token(self, token, scopes, request): + if not token: + raise InvalidTokenError() + if token.is_expired() or token.is_revoked(): + raise InvalidTokenError() + if not match_token_scopes(token, scopes): + raise InsufficientScopeError() + """ + raise NotImplementedError() + + +class ResourceProtector: def __init__(self): self._token_validators = {} + self._default_realm = None + self._default_auth_type = None + + def register_token_validator(self, validator: TokenValidator): + """Register a token validator for a given Authorization type. + Authlib has a built-in BearerTokenValidator per rfc6750. + """ + if not self._default_auth_type: + self._default_realm = validator.realm + self._default_auth_type = validator.TOKEN_TYPE - def register_token_validator(self, validator): if validator.TOKEN_TYPE not in self._token_validators: self._token_validators[validator.TOKEN_TYPE] = validator - def validate_request(self, scope, request, scope_operator='AND'): - auth = request.headers.get('Authorization') + def get_token_validator(self, token_type): + """Get token validator from registry for the given token type.""" + validator = self._token_validators.get(token_type.lower()) + if not validator: + raise UnsupportedTokenTypeError( + self._default_auth_type, self._default_realm + ) + return validator + + def parse_request_authorization(self, request): + """Parse the token and token validator from request Authorization header. + Here is an example of Authorization header:: + + Authorization: Bearer a-token-string + + This method will parse this header, if it can find the validator for + ``Bearer``, it will return the validator and ``a-token-string``. + + :return: validator, token_string + :raise: MissingAuthorizationError + :raise: UnsupportedTokenTypeError + """ + auth = request.headers.get("Authorization") if not auth: - raise MissingAuthorizationError() + raise MissingAuthorizationError( + self._default_auth_type, self._default_realm + ) # https://tools.ietf.org/html/rfc6749#section-7.1 token_parts = auth.split(None, 1) if len(token_parts) != 2: - raise UnsupportedTokenTypeError() + raise UnsupportedTokenTypeError( + self._default_auth_type, self._default_realm + ) token_type, token_string = token_parts + validator = self.get_token_validator(token_type) + return validator, token_string - validator = self._token_validators.get(token_type.lower()) - if not validator: - raise UnsupportedTokenTypeError() - - return validator(token_string, scope, request, scope_operator) + def validate_request(self, scopes, request, **kwargs): + """Validate the request and return a token.""" + validator, token_string = self.parse_request_authorization(request) + validator.validate_request(request) + token = validator.authenticate_token(token_string) + validator.validate_token(token, scopes, request, **kwargs) + return token diff --git a/authlib/oauth2/rfc6749/token_endpoint.py b/authlib/oauth2/rfc6749/token_endpoint.py index 726f7e0fb..377b9e32d 100644 --- a/authlib/oauth2/rfc6749/token_endpoint.py +++ b/authlib/oauth2/rfc6749/token_endpoint.py @@ -1,34 +1,30 @@ -class TokenEndpoint(object): - #: Endpoint name to be registered - ENDPOINT_NAME = None - #: Supported token types - SUPPORTED_TOKEN_TYPES = ('access_token', 'refresh_token') - #: Allowed client authenticate methods - CLIENT_AUTH_METHODS = ['client_secret_basic'] +from .endpoint import Endpoint + - def __init__(self, server): - self.server = server +class TokenEndpoint(Endpoint): + """Base class for token-based endpoints (revocation, introspection). - def __call__(self, request): - # make it callable for authorization server - # ``create_endpoint_response`` - return self.create_endpoint_response(request) + Subclasses must implement :meth:`authenticate_token` and + :meth:`create_endpoint_response`. + """ - def create_endpoint_request(self, request): - return self.server.create_oauth2_request(request) + #: Supported token types + SUPPORTED_TOKEN_TYPES = ("access_token", "refresh_token") + #: Allowed client authenticate methods + CLIENT_AUTH_METHODS = ["client_secret_basic"] def authenticate_endpoint_client(self, request): - """Authentication client for endpoint with ``CLIENT_AUTH_METHODS``. - """ + """Authenticate client for endpoint with ``CLIENT_AUTH_METHODS``.""" client = self.server.authenticate_client( - request=request, - methods=self.CLIENT_AUTH_METHODS, + request, self.CLIENT_AUTH_METHODS, self.ENDPOINT_NAME ) request.client = client return client - def authenticate_endpoint_credential(self, request, client): + def authenticate_token(self, request, client): + """Authenticate and return the token. Subclasses must implement this.""" raise NotImplementedError() def create_endpoint_response(self, request): + """Process the request and return response. Subclasses must implement this.""" raise NotImplementedError() diff --git a/authlib/oauth2/rfc6749/util.py b/authlib/oauth2/rfc6749/util.py index a216fbf34..93199245f 100644 --- a/authlib/oauth2/rfc6749/util.py +++ b/authlib/oauth2/rfc6749/util.py @@ -1,5 +1,7 @@ import base64 import binascii +from urllib.parse import unquote + from authlib.common.encoding import to_unicode @@ -22,19 +24,19 @@ def scope_to_list(scope): def extract_basic_authorization(headers): - auth = headers.get('Authorization') - if not auth or ' ' not in auth: + auth = headers.get("Authorization") + if not auth or " " not in auth: return None, None auth_type, auth_token = auth.split(None, 1) - if auth_type.lower() != 'basic': + if auth_type.lower() != "basic": return None, None try: query = to_unicode(base64.b64decode(auth_token)) except (binascii.Error, TypeError): return None, None - if ':' in query: - username, password = query.split(':', 1) - return username, password + if ":" in query: + username, password = query.split(":", 1) + return unquote(username), unquote(password) return query, None diff --git a/authlib/oauth2/rfc6749/wrappers.py b/authlib/oauth2/rfc6749/wrappers.py index a1f454312..ae3726f5a 100644 --- a/authlib/oauth2/rfc6749/wrappers.py +++ b/authlib/oauth2/rfc6749/wrappers.py @@ -1,96 +1,35 @@ import time -from authlib.common.urls import urlparse, url_decode -from .errors import InsecureTransportError class OAuth2Token(dict): def __init__(self, params): - if params.get('expires_at'): - params['expires_at'] = int(params['expires_at']) - elif params.get('expires_in'): - params['expires_at'] = int(time.time()) + \ - int(params['expires_in']) - super(OAuth2Token, self).__init__(params) - - def is_expired(self): - expires_at = self.get('expires_at') - if not expires_at: + if params.get("expires_at") is not None: + try: + params["expires_at"] = int(params["expires_at"]) + except ValueError: + # If expires_at is not parseable, fall back to expires_in if available + # Otherwise leave expires_at untouched + if params.get("expires_in"): + params["expires_at"] = int(time.time()) + int(params["expires_in"]) + + elif params.get("expires_in"): + params["expires_at"] = int(time.time()) + int(params["expires_in"]) + + super().__init__(params) + + def is_expired(self, leeway=60): + expires_at = self.get("expires_at") + if expires_at is None: + return None + # Only check expiration if expires_at is an integer + if not isinstance(expires_at, int): return None - return expires_at < time.time() + # small timedelta to consider token as expired before it actually expires + expiration_threshold = expires_at - leeway + return expiration_threshold < time.time() @classmethod def from_dict(cls, token): if isinstance(token, dict) and not isinstance(token, cls): token = cls(token) return token - - -class OAuth2Request(object): - def __init__(self, method, uri, body=None, headers=None): - InsecureTransportError.check(uri) - #: HTTP method - self.method = method - self.uri = uri - self.body = body - #: HTTP headers - self.headers = headers or {} - - self.query = urlparse.urlparse(uri).query - - self.args = dict(url_decode(self.query)) - self.form = self.body or {} - - #: dict of query and body params - data = {} - data.update(self.args) - data.update(self.form) - self.data = data - - #: authenticate method - self.auth_method = None - #: authenticated user on this request - self.user = None - #: authorization_code or token model instance - self.credential = None - #: client which sending this request - self.client = None - - @property - def client_id(self): - """The authorization server issues the registered client a client - identifier -- a unique string representing the registration - information provided by the client. The value is extracted from - request. - - :return: string - """ - return self.data.get('client_id') - - @property - def response_type(self): - return self.data.get('response_type') - - @property - def grant_type(self): - return self.data.get('grant_type') - - @property - def redirect_uri(self): - return self.data.get('redirect_uri') - - @property - def scope(self): - return self.data.get('scope') - - @property - def state(self): - return self.data.get('state') - - -class HttpRequest(object): - def __init__(self, method, uri, data=None, headers=None): - self.method = method - self.uri = uri - self.data = data - self.headers = headers or {} - self.user = None diff --git a/authlib/oauth2/rfc6750/__init__.py b/authlib/oauth2/rfc6750/__init__.py index 4ad021269..f7878b59e 100644 --- a/authlib/oauth2/rfc6750/__init__.py +++ b/authlib/oauth2/rfc6750/__init__.py @@ -1,23 +1,27 @@ -# -*- coding: utf-8 -*- -""" - authlib.oauth2.rfc6750 - ~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth2.rfc6750. +~~~~~~~~~~~~~~~~~~~~~~ - This module represents a direct implementation of - The OAuth 2.0 Authorization Framework: Bearer Token Usage. +This module represents a direct implementation of +The OAuth 2.0 Authorization Framework: Bearer Token Usage. - https://tools.ietf.org/html/rfc6750 +https://tools.ietf.org/html/rfc6750 """ -from .errors import InvalidRequestError, InvalidTokenError, InsufficientScopeError +from .errors import InsufficientScopeError +from .errors import InvalidTokenError from .parameters import add_bearer_token -from .wrappers import BearerToken +from .token import BearerTokenGenerator from .validator import BearerTokenValidator +# TODO: add deprecation +BearerToken = BearerTokenGenerator + __all__ = [ - 'InvalidRequestError', 'InvalidTokenError', 'InsufficientScopeError', - 'add_bearer_token', - 'BearerToken', - 'BearerTokenValidator', + "InvalidTokenError", + "InsufficientScopeError", + "add_bearer_token", + "BearerToken", + "BearerTokenGenerator", + "BearerTokenValidator", ] diff --git a/authlib/oauth2/rfc6750/errors.py b/authlib/oauth2/rfc6750/errors.py index 543fa9a54..c897616b4 100644 --- a/authlib/oauth2/rfc6750/errors.py +++ b/authlib/oauth2/rfc6750/errors.py @@ -1,22 +1,19 @@ -""" - authlib.rfc6750.errors - ~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.rfc6750.errors. +~~~~~~~~~~~~~~~~~~~~~~ - OAuth Extensions Error Registration. When a request fails, - the resource server responds using the appropriate HTTP - status code and includes one of the following error codes - in the response. +OAuth Extensions Error Registration. When a request fails, +the resource server responds using the appropriate HTTP +status code and includes one of the following error codes +in the response. - https://tools.ietf.org/html/rfc6750#section-6.2 +https://tools.ietf.org/html/rfc6750#section-6.2 - :copyright: (c) 2017 by Hsiaoming Yang. +:copyright: (c) 2017 by Hsiaoming Yang. """ + from ..base import OAuth2Error -from ..rfc6749.errors import InvalidRequestError -__all__ = [ - 'InvalidRequestError', 'InvalidTokenError', 'InsufficientScopeError' -] +__all__ = ["InvalidTokenError", "InsufficientScopeError"] class InvalidTokenError(OAuth2Error): @@ -28,20 +25,26 @@ class InvalidTokenError(OAuth2Error): https://tools.ietf.org/html/rfc6750#section-3.1 """ - error = 'invalid_token' + + error = "invalid_token" + description = ( + "The access token provided is expired, revoked, malformed, " + "or invalid for other reasons." + ) status_code = 401 - def __init__(self, description=None, uri=None, status_code=None, - state=None, realm=None): - super(InvalidTokenError, self).__init__( - description, uri, status_code, state) + def __init__( + self, + description=None, + uri=None, + status_code=None, + state=None, + realm=None, + extra_attributes=None, + ): + super().__init__(description, uri, status_code, state) self.realm = realm - - def get_error_description(self): - return self.gettext( - 'The access token provided is expired, revoked, malformed, ' - 'or invalid for other reasons.' - ) + self.extra_attributes = extra_attributes or {} def get_headers(self): """If the protected resource request does not include authentication @@ -52,17 +55,19 @@ def get_headers(self): https://tools.ietf.org/html/rfc6750#section-3 """ - headers = super(InvalidTokenError, self).get_headers() + headers = super().get_headers() extras = [] if self.realm: - extras.append('realm="{}"'.format(self.realm)) - extras.append('error="{}"'.format(self.error)) + extras.append(f'realm="{self.realm}"') + if self.extra_attributes: + extras.extend( + [f'{k}="{self.extra_attributes[k]}"' for k in self.extra_attributes] + ) + extras.append(f'error="{self.error}"') error_description = self.get_error_description() - extras.append('error_description="{}"'.format(error_description)) - headers.append( - ('WWW-Authenticate', 'Bearer ' + ', '.join(extras)) - ) + extras.append(f'error_description="{error_description}"') + headers.append(("WWW-Authenticate", "Bearer " + ", ".join(extras))) return headers @@ -75,11 +80,9 @@ class InsufficientScopeError(OAuth2Error): https://tools.ietf.org/html/rfc6750#section-3.1 """ - error = 'insufficient_scope' - status_code = 403 - def get_error_description(self): - return self.gettext( - 'The request requires higher privileges than ' - 'provided by the access token.' - ) + error = "insufficient_scope" + description = ( + "The request requires higher privileges than provided by the access token." + ) + status_code = 403 diff --git a/authlib/oauth2/rfc6750/parameters.py b/authlib/oauth2/rfc6750/parameters.py index 5f4e1006e..6bb94f92a 100644 --- a/authlib/oauth2/rfc6750/parameters.py +++ b/authlib/oauth2/rfc6750/parameters.py @@ -1,4 +1,5 @@ -from authlib.common.urls import add_params_to_qs, add_params_to_uri +from authlib.common.urls import add_params_to_qs +from authlib.common.urls import add_params_to_uri def add_to_uri(token, uri): @@ -7,7 +8,7 @@ def add_to_uri(token, uri): http://www.example.com/path?access_token=h480djs93hd8 """ - return add_params_to_uri(uri, [('access_token', token)]) + return add_params_to_uri(uri, [("access_token", token)]) def add_to_headers(token, headers=None): @@ -17,7 +18,7 @@ def add_to_headers(token, headers=None): Authorization: Bearer h480djs93hd8 """ headers = headers or {} - headers['Authorization'] = 'Bearer {}'.format(token) + headers["Authorization"] = f"Bearer {token}" return headers @@ -27,15 +28,15 @@ def add_to_body(token, body=None): access_token=h480djs93hd8 """ if body is None: - body = '' - return add_params_to_qs(body, [('access_token', token)]) + body = "" + return add_params_to_qs(body, [("access_token", token)]) -def add_bearer_token(token, uri, headers, body, placement='header'): - if placement in ('uri', 'url', 'query'): +def add_bearer_token(token, uri, headers, body, placement="header"): + if placement in ("uri", "url", "query"): uri = add_to_uri(token, uri) - elif placement in ('header', 'headers'): + elif placement in ("header", "headers"): headers = add_to_headers(token, headers) - elif placement == 'body': + elif placement == "body": body = add_to_body(token, body) return uri, headers, body diff --git a/authlib/oauth2/rfc6750/token.py b/authlib/oauth2/rfc6750/token.py new file mode 100644 index 000000000..d73db2b50 --- /dev/null +++ b/authlib/oauth2/rfc6750/token.py @@ -0,0 +1,125 @@ +from ..rfc6749.errors import InvalidScopeError + + +class BearerTokenGenerator: + """Bearer token generator which can create the payload for token response + by OAuth 2 server. A typical token response would be: + + .. code-block:: http + + HTTP/1.1 200 OK + Content-Type: application/json;charset=UTF-8 + Cache-Control: no-store + Pragma: no-cache + + { + "access_token":"mF_9.B5f-4.1JqM", + "token_type":"Bearer", + "expires_in":3600, + "refresh_token":"tGzv3JOkF0XG5Qx2TlKWIA" + } + """ + + #: default expires_in value + DEFAULT_EXPIRES_IN = 3600 + #: default expires_in value differentiate by grant_type + GRANT_TYPES_EXPIRES_IN = { + "authorization_code": 864000, + "implicit": 3600, + "password": 864000, + "client_credentials": 864000, + } + + def __init__( + self, + access_token_generator, + refresh_token_generator=None, + expires_generator=None, + ): + self.access_token_generator = access_token_generator + self.refresh_token_generator = refresh_token_generator + self.expires_generator = expires_generator + + def _get_expires_in(self, client, grant_type): + if self.expires_generator is None: + expires_in = self.GRANT_TYPES_EXPIRES_IN.get( + grant_type, self.DEFAULT_EXPIRES_IN + ) + elif callable(self.expires_generator): + expires_in = self.expires_generator(client, grant_type) + elif isinstance(self.expires_generator, int): + expires_in = self.expires_generator + else: + expires_in = self.DEFAULT_EXPIRES_IN + return expires_in + + @staticmethod + def get_allowed_scope(client, scope): + """Get the allowed scope for token generation. + + Per RFC 6749 Section 3.3, if the client omits the scope parameter, + the authorization server MUST either process the request using a + pre-defined default value or fail the request indicating an invalid scope. + + :param client: the client making the request + :param scope: the requested scope (may be None if omitted) + :return: the allowed scope string + :raises InvalidScopeError: if client.get_allowed_scope returns None + """ + scope = client.get_allowed_scope(scope) + if scope is None: + raise InvalidScopeError() + return scope + + def generate( + self, + grant_type, + client, + user=None, + scope=None, + expires_in=None, + include_refresh_token=True, + ): + """Generate a bearer token for OAuth 2.0 authorization token endpoint. + + :param client: the client that making the request. + :param grant_type: current requested grant_type. + :param user: current authorized user. + :param expires_in: if provided, use this value as expires_in. + :param scope: current requested scope. + :param include_refresh_token: should refresh_token be included. + :return: Token dict + """ + scope = self.get_allowed_scope(client, scope) + access_token = self.access_token_generator( + client=client, grant_type=grant_type, user=user, scope=scope + ) + if expires_in is None: + expires_in = self._get_expires_in(client, grant_type) + + token = { + "token_type": "Bearer", + "access_token": access_token, + } + if expires_in: + token["expires_in"] = expires_in + if include_refresh_token and self.refresh_token_generator: + token["refresh_token"] = self.refresh_token_generator( + client=client, grant_type=grant_type, user=user, scope=scope + ) + if scope: + token["scope"] = scope + return token + + def __call__( + self, + grant_type, + client, + user=None, + scope=None, + expires_in=None, + include_refresh_token=True, + ): + return self.generate( + grant_type, client, user, scope, expires_in, include_refresh_token + ) diff --git a/authlib/oauth2/rfc6750/validator.py b/authlib/oauth2/rfc6750/validator.py index 31467aa68..a9716ec5f 100644 --- a/authlib/oauth2/rfc6750/validator.py +++ b/authlib/oauth2/rfc6750/validator.py @@ -1,24 +1,16 @@ -""" - authlib.oauth2.rfc6750.validator - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth2.rfc6750.validator. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - Validate Bearer Token for in request, scope and token. +Validate Bearer Token for in request, scope and token. """ -import time -from ..rfc6749.util import scope_to_list -from .errors import ( - InvalidRequestError, - InvalidTokenError, - InsufficientScopeError -) - +from ..rfc6749 import TokenValidator +from .errors import InsufficientScopeError +from .errors import InvalidTokenError -class BearerTokenValidator(object): - TOKEN_TYPE = 'bearer' - def __init__(self, realm=None): - self.realm = realm +class BearerTokenValidator(TokenValidator): + TOKEN_TYPE = "bearer" def authenticate_token(self, token_string): """A method to query token from database with the given token string. @@ -32,68 +24,19 @@ def authenticate_token(self, token_string): """ raise NotImplementedError() - def request_invalid(self, request): - """Check if the HTTP request is valid or not. Developers MUST - re-implement this method. For instance, your server requires a - "X-Device-Version" in the header:: - - def request_invalid(self, request): - return 'X-Device-Version' in request.headers - - Usually, you don't have to detect if the request is valid or not, - you can just return a ``False``. - - :param request: instance of HttpRequest - :return: Boolean - """ - raise NotImplementedError() - - def token_revoked(self, token): - """Check if this token is revoked. Developers MUST re-implement this - method. If there is a column called ``revoked`` on the token table:: - - def token_revoked(self, token): - return token.revoked - - :param token: token instance - :return: Boolean - """ - raise NotImplementedError() - - def token_expired(self, token): - expires_at = token.get_expires_at() - if not expires_at: - return False - return expires_at < time.time() - - def scope_insufficient(self, token, scope, operator='AND'): - if not scope: - return False - - token_scopes = scope_to_list(token.get_scope()) - if not token_scopes: - return True - - token_scopes = set(token_scopes) - resource_scopes = set(scope_to_list(scope)) - if operator == 'AND': - return not token_scopes.issuperset(resource_scopes) - if operator == 'OR': - return not token_scopes & resource_scopes - if callable(operator): - return not operator(token_scopes, resource_scopes) - raise ValueError('Invalid operator value') - - def __call__(self, token_string, scope, request, scope_operator='AND'): - if self.request_invalid(request): - raise InvalidRequestError() - token = self.authenticate_token(token_string) + def validate_token(self, token, scopes, request): + """Check if token is active and matches the requested scopes.""" if not token: - raise InvalidTokenError(realm=self.realm) - if self.token_expired(token): - raise InvalidTokenError(realm=self.realm) - if self.token_revoked(token): - raise InvalidTokenError(realm=self.realm) - if self.scope_insufficient(token, scope, scope_operator): + raise InvalidTokenError( + realm=self.realm, extra_attributes=self.extra_attributes + ) + if token.is_expired(): + raise InvalidTokenError( + realm=self.realm, extra_attributes=self.extra_attributes + ) + if token.is_revoked(): + raise InvalidTokenError( + realm=self.realm, extra_attributes=self.extra_attributes + ) + if self.scope_insufficient(token.get_scope(), scopes): raise InsufficientScopeError() - return token diff --git a/authlib/oauth2/rfc6750/wrappers.py b/authlib/oauth2/rfc6750/wrappers.py deleted file mode 100644 index 9e2c226c3..000000000 --- a/authlib/oauth2/rfc6750/wrappers.py +++ /dev/null @@ -1,97 +0,0 @@ - -class BearerToken(object): - """Bearer Token generator which can create the payload for token response - by OAuth 2 server. A typical token response would be: - - .. code-block:: http - - HTTP/1.1 200 OK - Content-Type: application/json;charset=UTF-8 - Cache-Control: no-store - Pragma: no-cache - - { - "access_token":"mF_9.B5f-4.1JqM", - "token_type":"Bearer", - "expires_in":3600, - "refresh_token":"tGzv3JOkF0XG5Qx2TlKWIA" - } - - :param access_token_generator: a function to generate access_token. - :param refresh_token_generator: a function to generate refresh_token, - if not provided, refresh_token will not be added into token response. - :param expires_generator: The expires_generator can be an int value or a - function. If it is int, all token expires_in will be this value. If it - is function, it can generate expires_in depending on client and - grant_type:: - - def expires_generator(client, grant_type): - if is_official_client(client): - return 3600 * 1000 - if grant_type == 'implicit': - return 3600 - return 3600 * 10 - :return: Callable - - When BearerToken is initialized, it will be callable:: - - token_generator = BearerToken(access_token_generator) - token = token_generator(client, grant_type, expires_in=None, - scope=None, include_refresh_token=True) - - The callable function that BearerToken created accepts these parameters: - - :param client: the client that making the request. - :param grant_type: current requested grant_type. - :param expires_in: if provided, use this value as expires_in. - :param scope: current requested scope. - :param include_refresh_token: should refresh_token be included. - :return: Token dict - """ - - #: default expires_in value - DEFAULT_EXPIRES_IN = 3600 - #: default expires_in value differentiate by grant_type - GRANT_TYPES_EXPIRES_IN = { - 'authorization_code': 864000, - 'implicit': 3600, - 'password': 864000, - 'client_credentials': 864000 - } - - def __init__(self, access_token_generator, - refresh_token_generator=None, - expires_generator=None): - self.access_token_generator = access_token_generator - self.refresh_token_generator = refresh_token_generator - self.expires_generator = expires_generator - - def _get_expires_in(self, client, grant_type): - if self.expires_generator is None: - expires_in = self.GRANT_TYPES_EXPIRES_IN.get( - grant_type, self.DEFAULT_EXPIRES_IN) - elif callable(self.expires_generator): - expires_in = self.expires_generator(client, grant_type) - elif isinstance(self.expires_generator, int): - expires_in = self.expires_generator - else: - expires_in = self.DEFAULT_EXPIRES_IN - return expires_in - - def __call__(self, client, grant_type, user=None, scope=None, - expires_in=None, include_refresh_token=True): - access_token = self.access_token_generator(client, grant_type, user, scope) - if expires_in is None: - expires_in = self._get_expires_in(client, grant_type) - - token = { - 'token_type': 'Bearer', - 'access_token': access_token, - 'expires_in': expires_in - } - if include_refresh_token and self.refresh_token_generator: - token['refresh_token'] = self.refresh_token_generator( - client, grant_type, user, scope) - if scope: - token['scope'] = scope - return token diff --git a/authlib/oauth2/rfc7009/__init__.py b/authlib/oauth2/rfc7009/__init__.py index 0b8bc7f2e..c355a19c4 100644 --- a/authlib/oauth2/rfc7009/__init__.py +++ b/authlib/oauth2/rfc7009/__init__.py @@ -1,15 +1,13 @@ -# -*- coding: utf-8 -*- -""" - authlib.oauth2.rfc7009 - ~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth2.rfc7009. +~~~~~~~~~~~~~~~~~~~~~~ - This module represents a direct implementation of - OAuth 2.0 Token Revocation. +This module represents a direct implementation of +OAuth 2.0 Token Revocation. - https://tools.ietf.org/html/rfc7009 +https://tools.ietf.org/html/rfc7009 """ from .parameters import prepare_revoke_token_request from .revocation import RevocationEndpoint -__all__ = ['prepare_revoke_token_request', 'RevocationEndpoint'] +__all__ = ["prepare_revoke_token_request", "RevocationEndpoint"] diff --git a/authlib/oauth2/rfc7009/parameters.py b/authlib/oauth2/rfc7009/parameters.py index 2a829a752..dbbe2db7f 100644 --- a/authlib/oauth2/rfc7009/parameters.py +++ b/authlib/oauth2/rfc7009/parameters.py @@ -1,8 +1,7 @@ from authlib.common.urls import add_params_to_qs -def prepare_revoke_token_request(token, token_type_hint=None, - body=None, headers=None): +def prepare_revoke_token_request(token, token_type_hint=None, body=None, headers=None): """Construct request body and headers for revocation endpoint. :param token: access_token or refresh_token string. @@ -13,13 +12,13 @@ def prepare_revoke_token_request(token, token_type_hint=None, https://tools.ietf.org/html/rfc7009#section-2.1 """ - params = [('token', token)] + params = [("token", token)] if token_type_hint: - params.append(('token_type_hint', token_type_hint)) + params.append(("token_type_hint", token_type_hint)) - body = add_params_to_qs(body or '', params) + body = add_params_to_qs(body or "", params) if headers is None: headers = {} - headers['Content-Type'] = 'application/x-www-form-urlencoded' + headers["Content-Type"] = "application/x-www-form-urlencoded" return body, headers diff --git a/authlib/oauth2/rfc7009/revocation.py b/authlib/oauth2/rfc7009/revocation.py index aafdd37d4..0dd85d081 100644 --- a/authlib/oauth2/rfc7009/revocation.py +++ b/authlib/oauth2/rfc7009/revocation.py @@ -1,9 +1,9 @@ from authlib.consts import default_json_headers + +from ..rfc6749 import InvalidGrantError +from ..rfc6749 import InvalidRequestError from ..rfc6749 import TokenEndpoint -from ..rfc6749 import ( - InvalidRequestError, - UnsupportedTokenTypeError, -) +from ..rfc6749 import UnsupportedTokenTypeError class RevocationEndpoint(TokenEndpoint): @@ -12,10 +12,11 @@ class RevocationEndpoint(TokenEndpoint): .. _RFC7009: https://tools.ietf.org/html/rfc7009 """ + #: Endpoint name to be registered - ENDPOINT_NAME = 'revocation' + ENDPOINT_NAME = "revocation" - def authenticate_endpoint_credential(self, request, client): + def authenticate_token(self, request, client): """The client constructs the request by including the following parameters using the "application/x-www-form-urlencoded" format in the HTTP request entity-body: @@ -27,13 +28,21 @@ def authenticate_endpoint_credential(self, request, client): OPTIONAL. A hint about the type of the token submitted for revocation. """ - if 'token' not in request.form: + self.check_params(request, client) + token = self.query_token( + request.form["token"], request.form.get("token_type_hint") + ) + if token and not token.check_client(client): + raise InvalidGrantError() + return token + + def check_params(self, request, client): + if "token" not in request.form: raise InvalidRequestError() - token_type = request.form.get('token_type_hint') - if token_type and token_type not in self.SUPPORTED_TOKEN_TYPES: + hint = request.form.get("token_type_hint") + if hint and hint not in self.SUPPORTED_TOKEN_TYPES: raise UnsupportedTokenTypeError() - return self.query_token(request.form['token'], token_type, client) def create_endpoint_response(self, request): """Validate revocation request and create the response for revocation. @@ -54,33 +63,33 @@ def create_endpoint_response(self, request): # then verifies whether the token was issued to the client making # the revocation request - credential = self.authenticate_endpoint_credential(request, client) + token = self.authenticate_token(request, client) # the authorization server invalidates the token - if credential: - self.revoke_token(credential) + if token: + self.revoke_token(token, request) self.server.send_signal( - 'after_revoke_token', - token=credential, + "after_revoke_token", + token=token, client=client, ) return 200, {}, default_json_headers - def query_token(self, token, token_type_hint, client): + def query_token(self, token_string, token_type_hint): """Get the token from database/storage by the given token string. Developers should implement this method:: - def query_token(self, token, token_type_hint, client): + def query_token(self, token_string, token_type_hint): if token_type_hint == 'access_token': - return Token.query_by_access_token(token, client.client_id) + return Token.query_by_access_token(token_string) if token_type_hint == 'refresh_token': - return Token.query_by_refresh_token(token, client.client_id) - return Token.query_by_access_token(token, client.client_id) or \ - Token.query_by_refresh_token(token, client.client_id) + return Token.query_by_refresh_token(token_string) + return Token.query_by_access_token(token_string) or \ + Token.query_by_refresh_token(token_string) """ raise NotImplementedError() - def revoke_token(self, token): + def revoke_token(self, token, request): """Mark token as revoked. Since token MUST be unique, it would be dangerous to delete it. Consider this situation: @@ -91,8 +100,13 @@ def revoke_token(self, token): It would be secure to mark a token as revoked:: - def revoke_token(self, token): - token.revoked = True + def revoke_token(self, token, request): + hint = request.form.get("token_type_hint") + if hint == "access_token": + token.access_token_revoked = True + else: + token.access_token_revoked = True + token.refresh_token_revoked = True token.save() """ raise NotImplementedError() diff --git a/authlib/oauth2/rfc7521/__init__.py b/authlib/oauth2/rfc7521/__init__.py index 0dbe0b30b..86e57652a 100644 --- a/authlib/oauth2/rfc7521/__init__.py +++ b/authlib/oauth2/rfc7521/__init__.py @@ -1,3 +1,3 @@ from .client import AssertionClient -__all__ = ['AssertionClient'] +__all__ = ["AssertionClient"] diff --git a/authlib/oauth2/rfc7521/client.py b/authlib/oauth2/rfc7521/client.py index d1b98ba55..decbd1306 100644 --- a/authlib/oauth2/rfc7521/client.py +++ b/authlib/oauth2/rfc7521/client.py @@ -2,20 +2,32 @@ from authlib.oauth2.base import OAuth2Error -class AssertionClient(object): +class AssertionClient: """Constructs a new Assertion Framework for OAuth 2.0 Authorization Grants per RFC7521_. .. _RFC7521: https://tools.ietf.org/html/rfc7521 """ + DEFAULT_GRANT_TYPE = None ASSERTION_METHODS = {} token_auth_class = None - - def __init__(self, session, token_endpoint, issuer, subject, - audience=None, grant_type=None, claims=None, - token_placement='header', scope=None, **kwargs): - + oauth_error_class = OAuth2Error + + def __init__( + self, + session, + token_endpoint, + issuer, + subject, + audience=None, + grant_type=None, + claims=None, + token_placement="header", + scope=None, + leeway=60, + **kwargs, + ): self.session = session if audience is None: @@ -37,6 +49,7 @@ def __init__(self, session, token_endpoint, issuer, subject, if self.token_auth_class is not None: self.token_auth = self.token_auth_class(None, token_placement, self) self._kwargs = kwargs + self.leeway = leeway @property def token(self): @@ -58,27 +71,37 @@ def refresh_token(self): subject=self.subject, audience=self.audience, claims=self.claims, - **self._kwargs + **self._kwargs, ) data = { - 'assertion': to_native(assertion), - 'grant_type': self.grant_type, + "assertion": to_native(assertion), + "grant_type": self.grant_type, } if self.scope: - data['scope'] = self.scope + data["scope"] = self.scope return self._refresh_token(data) - def _refresh_token(self, data): - resp = self.session.request( - 'POST', self.token_endpoint, data=data, withhold_token=True) + def parse_response_token(self, resp): + if resp.status_code >= 500: + resp.raise_for_status() token = resp.json() - if 'error' in token: - raise OAuth2Error( - error=token['error'], - description=token.get('error_description') + if "error" in token: + raise self.oauth_error_class( + error=token["error"], description=token.get("error_description") ) self.token = token return self.token + + def _refresh_token(self, data): + resp = self.session.request( + "POST", self.token_endpoint, data=data, withhold_token=True + ) + + return self.parse_response_token(resp) + + def __del__(self): + if self.session: + del self.session diff --git a/authlib/oauth2/rfc7523/__init__.py b/authlib/oauth2/rfc7523/__init__.py index 843c07503..29dfd1c30 100644 --- a/authlib/oauth2/rfc7523/__init__.py +++ b/authlib/oauth2/rfc7523/__init__.py @@ -1,34 +1,31 @@ -# -*- coding: utf-8 -*- -""" - authlib.oauth2.rfc7523 - ~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth2.rfc7523. +~~~~~~~~~~~~~~~~~~~~~~ - This module represents a direct implementation of - JSON Web Token (JWT) Profile for OAuth 2.0 Client - Authentication and Authorization Grants. +This module represents a direct implementation of +JSON Web Token (JWT) Profile for OAuth 2.0 Client +Authentication and Authorization Grants. - https://tools.ietf.org/html/rfc7523 +https://tools.ietf.org/html/rfc7523 """ +from .assertion import client_secret_jwt_sign +from .assertion import private_key_jwt_sign +from .auth import ClientSecretJWT +from .auth import PrivateKeyJWT +from .client import JWTBearerClientAssertion from .jwt_bearer import JWTBearerGrant -from .client import ( - JWTBearerClientAssertion, -) -from .assertion import ( - client_secret_jwt_sign, - private_key_jwt_sign, -) -from .auth import ( - ClientSecretJWT, PrivateKeyJWT, - register_session_client_auth_method, -) +from .token import JWTBearerTokenGenerator +from .validator import JWTBearerToken +from .validator import JWTBearerTokenValidator __all__ = [ - 'JWTBearerGrant', - 'JWTBearerClientAssertion', - 'client_secret_jwt_sign', - 'private_key_jwt_sign', - 'ClientSecretJWT', - 'PrivateKeyJWT', - 'register_session_client_auth_method', + "JWTBearerGrant", + "JWTBearerClientAssertion", + "client_secret_jwt_sign", + "private_key_jwt_sign", + "ClientSecretJWT", + "PrivateKeyJWT", + "JWTBearerToken", + "JWTBearerTokenGenerator", + "JWTBearerTokenValidator", ] diff --git a/authlib/oauth2/rfc7523/assertion.py b/authlib/oauth2/rfc7523/assertion.py index 0bb9fe7be..47e7bc57b 100644 --- a/authlib/oauth2/rfc7523/assertion.py +++ b/authlib/oauth2/rfc7523/assertion.py @@ -1,49 +1,61 @@ import time -from authlib.jose import jwt + +from joserfc import jwt + +from authlib._joserfc_helpers import import_any_key from authlib.common.security import generate_token def sign_jwt_bearer_assertion( - key, issuer, audience, subject=None, issued_at=None, - expires_at=None, claims=None, header=None, **kwargs): - + key, + issuer, + audience, + subject=None, + issued_at=None, + expires_at=None, + claims=None, + header=None, + **kwargs, +): if header is None: header = {} - alg = kwargs.pop('alg', None) + alg = kwargs.pop("alg", None) if alg: - header['alg'] = alg - if 'alg' not in header: - raise ValueError('Missing "alg" in header') + header["alg"] = alg + if "alg" not in header: + raise ValueError("Missing 'alg' in header") - payload = {'iss': issuer, 'aud': audience} + payload = {"iss": issuer, "aud": audience} # subject is not required in Google service if subject: - payload['sub'] = subject + payload["sub"] = subject if not issued_at: issued_at = int(time.time()) - expires_in = kwargs.pop('expires_in', 3600) - if not expires_at: + expires_in = kwargs.pop("expires_in", 3600) + if expires_at is None: expires_at = issued_at + expires_in - payload['iat'] = issued_at - payload['exp'] = expires_at + payload["iat"] = issued_at + payload["exp"] = expires_at if claims: payload.update(claims) - return jwt.encode(header, payload, key) + return jwt.encode(header, payload, import_any_key(key)) -def client_secret_jwt_sign(client_secret, client_id, token_endpoint, alg='HS256', - claims=None, **kwargs): +def client_secret_jwt_sign( + client_secret, client_id, token_endpoint, alg="HS256", claims=None, **kwargs +): return _sign(client_secret, client_id, token_endpoint, alg, claims, **kwargs) -def private_key_jwt_sign(private_key, client_id, token_endpoint, alg='RS256', - claims=None, **kwargs): +def private_key_jwt_sign( + private_key, client_id, token_endpoint, alg="RS256", claims=None, **kwargs +): return _sign(private_key, client_id, token_endpoint, alg, claims, **kwargs) @@ -58,9 +70,15 @@ def _sign(key, client_id, token_endpoint, alg, claims=None, **kwargs): # jti is required if claims is None: claims = {} - if 'jti' not in claims: - claims['jti'] = generate_token(36) + if "jti" not in claims: + claims["jti"] = generate_token(36) return sign_jwt_bearer_assertion( - key=key, issuer=issuer, audience=audience, subject=subject, - claims=claims, alg=alg, **kwargs) + key=key, + issuer=issuer, + audience=audience, + subject=subject, + claims=claims, + alg=alg, + **kwargs, + ) diff --git a/authlib/oauth2/rfc7523/auth.py b/authlib/oauth2/rfc7523/auth.py index dddddc0b7..3da2a9595 100644 --- a/authlib/oauth2/rfc7523/auth.py +++ b/authlib/oauth2/rfc7523/auth.py @@ -1,10 +1,14 @@ +from joserfc.jwk import OctKey +from joserfc.jwk import RSAKey + from authlib.common.urls import add_params_to_qs -from authlib.deprecate import deprecate -from .assertion import client_secret_jwt_sign, private_key_jwt_sign + +from .assertion import client_secret_jwt_sign +from .assertion import private_key_jwt_sign from .client import ASSERTION_TYPE -class ClientSecretJWT(object): +class ClientSecretJWT: """Authentication method for OAuth 2.0 Client. This authentication method is called ``client_secret_jwt``, which is using ``client_id`` and ``client_secret`` constructed with JWT to identify a client. @@ -13,29 +17,43 @@ class ClientSecretJWT(object): from authlib.integrations.requests_client import OAuth2Session - token_endpoint = 'https://example.com/oauth/token' + token_endpoint = "https://example.com/oauth/token" session = OAuth2Session( - 'your-client-id', 'your-client-secret', - token_endpoint_auth_method='client_secret_jwt' + "your-client-id", + "your-client-secret", + token_endpoint_auth_method="client_secret_jwt", ) session.register_client_auth_method(ClientSecretJWT(token_endpoint)) session.fetch_token(token_endpoint) :param token_endpoint: A string URL of the token endpoint :param claims: Extra JWT claims + :param headers: Extra JWT headers + :param alg: ``alg`` value, default is HS256 """ - name = 'client_secret_jwt' - def __init__(self, token_endpoint=None, claims=None): + name = "client_secret_jwt" + alg = "HS256" + + def __init__(self, token_endpoint=None, claims=None, headers=None, alg=None): self.token_endpoint = token_endpoint self.claims = claims + self.headers = headers + if alg is not None: + self.alg = alg def sign(self, auth, token_endpoint): + if isinstance(auth.client_secret, OctKey): + key = auth.client_secret + else: + key = OctKey.import_key(auth.client_secret) return client_secret_jwt_sign( - auth.client_secret, + key, client_id=auth.client_id, token_endpoint=token_endpoint, claims=self.claims, + header=self.headers, + alg=self.alg, ) def __call__(self, auth, method, uri, headers, body): @@ -44,10 +62,13 @@ def __call__(self, auth, method, uri, headers, body): token_endpoint = uri client_assertion = self.sign(auth, token_endpoint) - body = add_params_to_qs(body or '', [ - ('client_assertion_type', ASSERTION_TYPE), - ('client_assertion', client_assertion) - ]) + body = add_params_to_qs( + body or "", + [ + ("client_assertion_type", ASSERTION_TYPE), + ("client_assertion", client_assertion), + ], + ) return uri, headers, body @@ -60,41 +81,34 @@ class PrivateKeyJWT(ClientSecretJWT): from authlib.integrations.requests_client import OAuth2Session - token_endpoint = 'https://example.com/oauth/token' + token_endpoint = "https://example.com/oauth/token" session = OAuth2Session( - 'your-client-id', 'your-client-private-key', - token_endpoint_auth_method='private_key_jwt' + "your-client-id", + "your-client-private-key", + token_endpoint_auth_method="private_key_jwt", ) session.register_client_auth_method(PrivateKeyJWT(token_endpoint)) session.fetch_token(token_endpoint) :param token_endpoint: A string URL of the token endpoint :param claims: Extra JWT claims + :param headers: Extra JWT headers + :param alg: ``alg`` value, default is RS256 """ - name = 'private_key_jwt' + + name = "private_key_jwt" + alg = "RS256" def sign(self, auth, token_endpoint): + if isinstance(auth.client_secret, RSAKey): + key = auth.client_secret + else: + key = RSAKey.import_key(auth.client_secret) return private_key_jwt_sign( - auth.client_secret, + key, client_id=auth.client_id, token_endpoint=token_endpoint, claims=self.claims, + header=self.headers, + alg=self.alg, ) - - -def register_session_client_auth_method(session, token_url=None, **kwargs): # pragma: no cover - """Register "client_secret_jwt" or "private_key_jwt" token endpoint auth - method to OAuth2Session. - - :param session: OAuth2Session instance. - :param token_url: Optional token endpoint url. - """ - deprecate('Use `ClientSecretJWT` and `PrivateKeyJWT` instead', '1.0', 'Jeclj', 'ca') - if session.token_endpoint_auth_method == 'client_secret_jwt': - cls = ClientSecretJWT - elif session.token_endpoint_auth_method == 'private_key_jwt': - cls = PrivateKeyJWT - else: - raise ValueError('Invalid token_endpoint_auth_method') - - session.register_client_auth_method(cls(token_url)) diff --git a/authlib/oauth2/rfc7523/client.py b/authlib/oauth2/rfc7523/client.py index cda82c84a..35d767551 100644 --- a/authlib/oauth2/rfc7523/client.py +++ b/authlib/oauth2/rfc7523/client.py @@ -1,49 +1,113 @@ +from __future__ import annotations + import logging -from authlib.jose import jwt -from authlib.jose.errors import JoseError + +from joserfc import jwk +from joserfc import jws +from joserfc import jwt +from joserfc.errors import JoseError +from joserfc.util import to_bytes + +from authlib._joserfc_helpers import import_any_key +from authlib.common.encoding import json_loads +from authlib.deprecate import deprecate + from ..rfc6749 import InvalidClientError -ASSERTION_TYPE = 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer' +ASSERTION_TYPE = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" log = logging.getLogger(__name__) -class JWTBearerClientAssertion(object): +class JWTBearerClientAssertion: """Implementation of Using JWTs for Client Authentication, which is defined by RFC7523. """ + #: Value of ``client_assertion_type`` of JWTs CLIENT_ASSERTION_TYPE = ASSERTION_TYPE #: Name of the client authentication method - CLIENT_AUTH_METHOD = 'client_assertion_jwt' + CLIENT_AUTH_METHOD = "client_assertion_jwt" - def __init__(self, token_url, validate_jti=True): + def __init__(self, token_url=None, validate_jti=True, leeway=60): + if token_url is not None: # pragma: no cover + deprecate( + "'token_url' is deprecated. Override 'get_audiences' instead.", + version="1.8", + ) self.token_url = token_url self._validate_jti = validate_jti + # A small allowance of time, typically no more than a few minutes, + # to account for clock skew. The default is 60 seconds. + self.leeway = leeway def __call__(self, query_client, request): data = request.form - assertion_type = data.get('client_assertion_type') - assertion = data.get('client_assertion') + assertion_type = data.get("client_assertion_type") + assertion = data.get("client_assertion") if assertion_type == ASSERTION_TYPE and assertion: - resolve_key = self.create_resolve_key_func(query_client, request) - self.process_assertion_claims(assertion, resolve_key) + headers, claims = self.extract_assertion(assertion) + client_id = claims["sub"] + client = query_client(client_id) + if not client: + raise InvalidClientError( + description="The client does not exist on this server." + ) + + try: + key = import_any_key(self.resolve_client_public_key(client)) + except TypeError: # pragma: no cover + key = import_any_key(self.resolve_client_public_key(client, headers)) + deprecate( + "resolve_client_public_key takes only 'client' parameter.", + version="1.8", + ) + + request.client = client + self.process_assertion_claims(assertion, key) return self.authenticate_client(request.client) - log.debug('Authenticate via %r failed', self.CLIENT_AUTH_METHOD) + log.debug("Authenticate via %r failed", self.CLIENT_AUTH_METHOD) - def create_claims_options(self): - """Create a claims_options for verify JWT payload claims. Developers - MAY overwrite this method to create a more strict options.""" - # https://tools.ietf.org/html/rfc7523#section-3 - # The Audience SHOULD be the URL of the Authorization Server's Token Endpoint + def verify_claims(self, claims: jwt.Claims): + # iss and sub MUST be the client_id options = { - 'iss': {'essential': True, 'validate': _validate_iss}, - 'sub': {'essential': True}, - 'aud': {'essential': True, 'value': self.token_url}, - 'exp': {'essential': True}, + "iss": {"essential": True}, + "sub": {"essential": True}, + "aud": {"essential": True, "values": self.get_audiences()}, + "exp": {"essential": True}, } + claims_requests = jwt.JWTClaimsRegistry(leeway=self.leeway, **options) + + try: + claims_requests.validate(claims) + except JoseError as e: + log.debug("Assertion Error: %r", e) + raise InvalidClientError(description=e.description) from e + + if claims["sub"] != claims["iss"]: + raise InvalidClientError(description="Issuer and Subject MUST match.") + if self._validate_jti: - options['jti'] = {'essential': True, 'validate': self.validate_jti} - return options + if "jti" not in claims: + raise InvalidClientError(description="Missing JWT ID.") + + if not self.validate_jti(claims, claims["jti"]): + raise InvalidClientError(description="JWT ID is used before.") + + def get_audiences(self): + """Return a list of valid audience identifiers for this authorization + server. Per RFC 7523 Section 3, the audience identifies the + authorization server as an intended audience. + + Developers MUST implement this method:: + + def get_audiences(self): + return ["https://example.com/oauth/token", "https://example.com"] + + :return: list of valid audience strings + """ + if self.token_url is not None: # pragma: no cover + return [self.token_url] + raise NotImplementedError() # pragma: no cover def process_assertion_claims(self, assertion, resolve_key): """Extract JWT payload claims from request "assertion", per @@ -57,40 +121,35 @@ def process_assertion_claims(self, assertion, resolve_key): .. _`Section 3.1`: https://tools.ietf.org/html/rfc7523#section-3.1 """ try: - claims = jwt.decode( - assertion, resolve_key, - claims_options=self.create_claims_options() - ) - claims.validate() + token = jwt.decode(assertion, resolve_key) except JoseError as e: - log.debug('Assertion Error: %r', e) - raise InvalidClientError() - return claims + log.debug("Assertion Error: %r", e) + raise InvalidClientError(description=e.description) from e + + self.verify_claims(token.claims) + return token.claims def authenticate_client(self, client): - if client.check_token_endpoint_auth_method(self.CLIENT_AUTH_METHOD): + if client.check_endpoint_auth_method(self.CLIENT_AUTH_METHOD, "token"): return client - raise InvalidClientError() - - def create_resolve_key_func(self, query_client, request): - def resolve_key(headers, payload): - # https://tools.ietf.org/html/rfc7523#section-3 - # For client authentication, the subject MUST be the - # "client_id" of the OAuth client - client_id = payload['sub'] - client = query_client(client_id) - if not client: - raise InvalidClientError() - request.client = client - return self.resolve_client_public_key(client, headers) - return resolve_key + raise InvalidClientError( + description=f"The client cannot authenticate with method: {self.CLIENT_AUTH_METHOD}" + ) + + def extract_assertion(self, assertion: str): + obj = jws.extract_compact(to_bytes(assertion)) + try: + claims = json_loads(obj.payload) + except ValueError: + raise InvalidClientError(description="Invalid JWT payload.") from None + return obj.headers(), claims def validate_jti(self, claims, jti): """Validate if the given ``jti`` value is used before. Developers MUST implement this method:: def validate_jti(self, claims, jti): - key = 'jti:{}-{}'.format(claims['sub'], jti) + key = "jti:{}-{}".format(claims["sub"], jti) if redis.get(key): return False redis.set(key, 1, ex=3600) @@ -98,16 +157,14 @@ def validate_jti(self, claims, jti): """ raise NotImplementedError() - def resolve_client_public_key(self, client, headers): + def resolve_client_public_key(self, client) -> jwk.Key | jwk.KeySet: """Resolve the client public key for verifying the JWT signature. - A client may have many public keys, in this case, we can retrieve it - via ``kid`` value in headers. Developers MUST implement this method:: + Developers MUST implement this method:: - def resolve_client_public_key(self, client, headers): - return client.public_key - """ - raise NotImplementedError() + from joserfc.jwk import KeySet -def _validate_iss(claims, iss): - return claims['sub'] == iss + def resolve_client_public_key(self, client): + return KeySet.import_key_set(client.public_jwks) + """ + raise NotImplementedError() diff --git a/authlib/oauth2/rfc7523/jwt_bearer.py b/authlib/oauth2/rfc7523/jwt_bearer.py index a11336d55..9954a3eae 100644 --- a/authlib/oauth2/rfc7523/jwt_bearer.py +++ b/authlib/oauth2/rfc7523/jwt_bearer.py @@ -1,39 +1,75 @@ import logging -from authlib.jose import jwt -from authlib.jose.errors import JoseError -from ..rfc6749 import BaseGrant, TokenEndpointMixin -from ..rfc6749 import ( - UnauthorizedClientError, - InvalidRequestError, - InvalidGrantError -) + +from joserfc import jwk +from joserfc import jws +from joserfc import jwt +from joserfc.errors import JoseError +from joserfc.util import to_bytes + +from authlib._joserfc_helpers import import_any_key +from authlib.common.encoding import json_loads +from authlib.deprecate import deprecate + +from ..rfc6749 import BaseGrant +from ..rfc6749 import InvalidClientError +from ..rfc6749 import InvalidGrantError +from ..rfc6749 import InvalidRequestError +from ..rfc6749 import TokenEndpointMixin +from ..rfc6749 import UnauthorizedClientError from .assertion import sign_jwt_bearer_assertion log = logging.getLogger(__name__) -JWT_BEARER_GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:jwt-bearer' +JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer" class JWTBearerGrant(BaseGrant, TokenEndpointMixin): GRANT_TYPE = JWT_BEARER_GRANT_TYPE + #: Options for verifying JWT payload claims. Developers MAY + #: overwrite this constant to create a more strict options. + CLAIMS_OPTIONS = { + "iss": {"essential": True}, + "aud": {"essential": True}, + "exp": {"essential": True}, + } + + # A small allowance of time, typically no more than a few minutes, + # to account for clock skew. The default is 60 seconds. + LEEWAY = 60 + @staticmethod - def sign(key, issuer, audience, subject=None, - issued_at=None, expires_at=None, claims=None, **kwargs): + def sign( + key, + issuer, + audience, + subject=None, + issued_at=None, + expires_at=None, + claims=None, + **kwargs, + ): return sign_jwt_bearer_assertion( - key, issuer, audience, subject, issued_at, - expires_at, claims, **kwargs) + key, issuer, audience, subject, issued_at, expires_at, claims, **kwargs + ) - def create_claims_options(self): - """Create a claims_options for verify JWT payload claims. Developers - MAY overwrite this method to create a more strict options. - """ - # https://tools.ietf.org/html/rfc7523#section-3 - return { - 'iss': {'essential': True}, - 'sub': {'essential': True}, - 'aud': {'essential': True}, - 'exp': {'essential': True}, - } + def verify_claims(self, claims: jwt.Claims): + options = dict(self.CLAIMS_OPTIONS) + audiences = self.get_audiences() + if audiences: + options["aud"] = {"essential": True, "values": audiences} + else: + deprecate( + "'get_audiences' must return a non-empty list. " + "Audience validation will become mandatory.", + version="1.8", + ) + + claims_requests = jwt.JWTClaimsRegistry(leeway=self.LEEWAY, **options) + try: + claims_requests.validate(claims) + except JoseError as e: + log.debug("Assertion Error: %r", e) + raise InvalidGrantError(description=e.description) from e def process_assertion_claims(self, assertion): """Extract JWT payload claims from request "assertion", per @@ -45,15 +81,37 @@ def process_assertion_claims(self, assertion): .. _`Section 3.1`: https://tools.ietf.org/html/rfc7523#section-3.1 """ - claims = jwt.decode( - assertion, self.resolve_public_key, - claims_options=self.create_claims_options()) + headers, claims = self.extract_assertion(assertion) + client = self.resolve_issuer_client(claims["iss"]) + + if hasattr(self, "resolve_client_key"): # pragma: no cover + key = import_any_key(self.resolve_client_key(client, headers, claims)) + deprecate( + "Use resolve_client_public_key instead of resolve_client_key.", + version="1.8", + ) + else: + key = import_any_key(self.resolve_client_public_key(client)) + try: - claims.validate() + token = jwt.decode(assertion, key) except JoseError as e: - log.debug('Assertion Error: %r', e) - raise InvalidGrantError(description=e.description) - return claims + log.debug("Assertion Error: %r", e) + raise InvalidGrantError(description=e.description) from e + except ValueError as e: + log.debug("Assertion Error: %r", e) + raise InvalidGrantError("Invalid JWT assertion") from None + + self.verify_claims(token.claims) + return token.claims + + def extract_assertion(self, assertion: str): + obj = jws.extract_compact(to_bytes(assertion)) + try: + claims = json_loads(obj.payload) + except ValueError: + raise InvalidGrantError(description="Invalid JWT payload.") from None + return obj.headers(), claims def validate_token_request(self): """The client makes a request to the token endpoint by sending the @@ -86,70 +144,119 @@ def validate_token_request(self): .. _`Section 2.1`: https://tools.ietf.org/html/rfc7523#section-2.1 """ - assertion = self.request.form.get('assertion') + assertion = self.request.form.get("assertion") if not assertion: - raise InvalidRequestError('Missing "assertion" in request') + raise InvalidRequestError("Missing 'assertion' in request") claims = self.process_assertion_claims(assertion) - client = self.authenticate_client(claims) - log.debug('Validate token request of %s', client) + client = self.resolve_issuer_client(claims["iss"]) + log.debug("Validate token request of %s", client) if not client.check_grant_type(self.GRANT_TYPE): - raise UnauthorizedClientError() + raise UnauthorizedClientError( + f"The client is not authorized to use 'grant_type={self.GRANT_TYPE}'" + ) self.request.client = client self.validate_requested_scope() - self.request.user = self.authenticate_user(client, claims) + + subject = claims.get("sub") + if subject: + user = self.authenticate_user(subject) + if not user: + raise InvalidGrantError(description="Invalid 'sub' value in assertion") + + log.debug("Check client(%s) permission to User(%s)", client, user) + if not self.has_granted_permission(client, user): + raise InvalidClientError( + description="Client has no permission to access user data" + ) + self.request.user = user def create_token_response(self): """If valid and authorized, the authorization server issues an access token. """ token = self.generate_token( - scope=self.request.scope, + scope=self.request.payload.scope, + user=self.request.user, include_refresh_token=False, ) - log.debug('Issue token %r to %r', token, self.request.client) + log.debug("Issue token %r to %r", token, self.request.client) self.save_token(token) return 200, token, self.TOKEN_RESPONSE_HEADER - def authenticate_user(self, client, claims): + def resolve_issuer_client(self, issuer): + """Fetch client via "iss" in assertion claims. Developers MUST + implement this method in subclass, e.g.:: + + def resolve_issuer_client(self, issuer): + return Client.query_by_iss(issuer) + + :param issuer: "iss" value in assertion + :return: Client instance + """ + raise NotImplementedError() + + def resolve_client_public_key(self, client) -> jwk.Key | jwk.KeySet: + """Resolve client key to decode assertion data. Developers MUST + implement this method in subclass. For instance, there is a + "jwks" column on client table, e.g.:: + + def resolve_client_public_key(self, client): + from joserfc import KeySet + + key_set = KeySet.import_key_set(client.jwks) + return key_set + + :param client: instance of OAuth client model + :return: OctKey, RSAKey, ECKey, OKPKey or KeySet instance + """ + raise NotImplementedError() + + def authenticate_user(self, subject): """Authenticate user with the given assertion claims. Developers MUST implement it in subclass, e.g.:: - def authenticate_user(self, client, claims): - user = User.get_by_sub(claims['sub']) - if is_authorized_to_client(user, client): - return user + def authenticate_user(self, subject): + return User.get_by_sub(subject) - :param client: OAuth Client instance - :param claims: assertion payload claims + :param subject: "sub" value in claims :return: User instance """ raise NotImplementedError() - def authenticate_client(self, claims): - """Authenticate client with the given assertion claims. Developers MUST - implement it in subclass, e.g.:: + def get_audiences(self): + """Return a list of valid audience identifiers for this authorization + server. Per RFC 7523 Section 3: - def authenticate_client(self, claims): - return Client.get_by_iss(claims['iss']) + The authorization server MUST reject any JWT that does not + contain its own identity as the intended audience. - :param claims: assertion payload claims - :return: Client instance + Developers SHOULD implement this method to return the list of valid + audience values, typically including the token endpoint URL and/or + the issuer identifier. For example:: + + def get_audiences(self): + return ["https://example.com/oauth/token", "https://example.com"] + + If this method returns an empty list, audience value validation is + skipped (only presence is checked). + + :return: list of valid audience strings """ - raise NotImplementedError() + return [] - def resolve_public_key(self, headers, payload): - """Find public key to verify assertion signature. Developers MUST - implement it in subclass, e.g.:: + def has_granted_permission(self, client, user): + """Check if the client has permission to access the given user's resource. + Developers MUST implement it in subclass, e.g.:: - def resolve_public_key(self, headers, payload): - jwk_set = get_jwk_set_by_iss(payload['iss']) - return filter_jwk_set(jwk_set, headers['kid']) + def has_granted_permission(self, client, user): + permission = ClientUserGrant.query(client=client, user=user) + return permission.granted - :param headers: JWT headers dict - :param payload: JWT payload dict - :return: A public key + :param client: instance of OAuth client model + :param user: instance of User model + :return: bool """ raise NotImplementedError() diff --git a/authlib/oauth2/rfc7523/token.py b/authlib/oauth2/rfc7523/token.py new file mode 100644 index 000000000..3122d5a04 --- /dev/null +++ b/authlib/oauth2/rfc7523/token.py @@ -0,0 +1,108 @@ +import time + +from joserfc import jwt + +from authlib._joserfc_helpers import import_any_key + + +class JWTBearerTokenGenerator: + """A JSON Web Token formatted bearer token generator for jwt-bearer grant type. + This token generator can be registered into authorization server:: + + authorization_server.register_token_generator( + "urn:ietf:params:oauth:grant-type:jwt-bearer", + JWTBearerTokenGenerator(private_rsa_key), + ) + + In this way, we can generate the token into JWT format. And we don't have to + save this token into database, since it will be short time valid. Consider to + rewrite ``JWTBearerGrant.save_token``:: + + class MyJWTBearerGrant(JWTBearerGrant): + def save_token(self, token): + pass + + :param secret_key: private RSA key in bytes, JWK or JWK Set. + :param issuer: a string or URI of the issuer + :param alg: ``alg`` to use in JWT + """ + + DEFAULT_EXPIRES_IN = 3600 + + def __init__(self, secret_key, issuer=None, alg="RS256"): + self.secret_key = import_any_key(secret_key) + self.issuer = issuer + self.alg = alg + + @staticmethod + def get_allowed_scope(client, scope): + if scope: + scope = client.get_allowed_scope(scope) + return scope + + @staticmethod + def get_sub_value(user): + """Return user's ID as ``sub`` value in token payload. For instance:: + + @staticmethod + def get_sub_value(user): + return str(user.id) + """ + return user.get_user_id() + + def get_token_data(self, grant_type, client, expires_in, user=None, scope=None): + scope = self.get_allowed_scope(client, scope) + issued_at = int(time.time()) + data = { + "scope": scope, + "grant_type": grant_type, + "iat": issued_at, + "exp": issued_at + expires_in, + "client_id": client.get_client_id(), + } + if self.issuer: + data["iss"] = self.issuer + if user: + data["sub"] = self.get_sub_value(user) + return data + + def generate(self, grant_type, client, user=None, scope=None, expires_in=None): + """Generate a bearer token for OAuth 2.0 authorization token endpoint. + + :param client: the client that making the request. + :param grant_type: current requested grant_type. + :param user: current authorized user. + :param expires_in: if provided, use this value as expires_in. + :param scope: current requested scope. + :return: Token dict + """ + if expires_in is None: + expires_in = self.DEFAULT_EXPIRES_IN + + token_data = self.get_token_data(grant_type, client, expires_in, user, scope) + access_token = jwt.encode( + {"alg": self.alg}, + claims=token_data, + key=self.secret_key, + algorithms=[self.alg], + ) + token = { + "token_type": "Bearer", + "access_token": access_token, + "expires_in": expires_in, + } + if scope: + token["scope"] = scope + return token + + def __call__( + self, + grant_type, + client, + user=None, + scope=None, + expires_in=None, + include_refresh_token=True, + ): + # there is absolutely no refresh token in JWT format + return self.generate(grant_type, client, user, scope, expires_in) diff --git a/authlib/oauth2/rfc7523/validator.py b/authlib/oauth2/rfc7523/validator.py new file mode 100644 index 000000000..62a0fe2c3 --- /dev/null +++ b/authlib/oauth2/rfc7523/validator.py @@ -0,0 +1,62 @@ +import logging +import time + +from joserfc import jwt +from joserfc.errors import JoseError + +from authlib._joserfc_helpers import import_any_key + +from ..rfc6749 import TokenMixin +from ..rfc6750 import BearerTokenValidator + +logger = logging.getLogger(__name__) + + +class JWTBearerToken(TokenMixin, dict): + def check_client(self, client): + return self["client_id"] == client.get_client_id() + + def get_scope(self): + return self.get("scope") + + def get_expires_in(self): + return self["exp"] - self["iat"] + + def is_expired(self): + return self["exp"] < time.time() + + def is_revoked(self): + return False + + +class JWTBearerTokenValidator(BearerTokenValidator): + TOKEN_TYPE = "bearer" + token_cls = JWTBearerToken + + def __init__(self, public_key, issuer=None, realm=None, **extra_attributes): + super().__init__(realm, **extra_attributes) + self.public_key = import_any_key(public_key) + claims_options = { + "exp": {"essential": True}, + "client_id": {"essential": True}, + "grant_type": {"essential": True}, + } + if issuer: + claims_options["iss"] = {"essential": True, "value": issuer} + self.claims_options = claims_options + + def authenticate_token(self, token_string: str): + try: + token = jwt.decode(token_string, self.public_key) + except JoseError as error: + logger.debug("Authenticate token failed. %r", error) + return None + + claims_requests = jwt.JWTClaimsRegistry(leeway=60, **self.claims_options) + try: + claims_requests.validate(token.claims) + except JoseError as error: + logger.debug("Authenticate token failed. %r", error) + return None + + return JWTBearerToken(token.claims) diff --git a/authlib/oauth2/rfc7591/__init__.py b/authlib/oauth2/rfc7591/__init__.py index 8ebb07096..8b25365d1 100644 --- a/authlib/oauth2/rfc7591/__init__.py +++ b/authlib/oauth2/rfc7591/__init__.py @@ -1,25 +1,24 @@ -""" - authlib.oauth2.rfc7591 - ~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth2.rfc7591. +~~~~~~~~~~~~~~~~~~~~~~ - This module represents a direct implementation of - OAuth 2.0 Dynamic Client Registration Protocol. +This module represents a direct implementation of +OAuth 2.0 Dynamic Client Registration Protocol. - https://tools.ietf.org/html/rfc7591 +https://tools.ietf.org/html/rfc7591 """ - from .claims import ClientMetadataClaims from .endpoint import ClientRegistrationEndpoint -from .errors import ( - InvalidRedirectURIError, - InvalidClientMetadataError, - InvalidSoftwareStatementError, - UnapprovedSoftwareStatementError, -) +from .errors import InvalidClientMetadataError +from .errors import InvalidRedirectURIError +from .errors import InvalidSoftwareStatementError +from .errors import UnapprovedSoftwareStatementError __all__ = [ - 'ClientMetadataClaims', 'ClientRegistrationEndpoint', - 'InvalidRedirectURIError', 'InvalidClientMetadataError', - 'InvalidSoftwareStatementError', 'UnapprovedSoftwareStatementError', + "ClientMetadataClaims", + "ClientRegistrationEndpoint", + "InvalidRedirectURIError", + "InvalidClientMetadataError", + "InvalidSoftwareStatementError", + "UnapprovedSoftwareStatementError", ] diff --git a/authlib/oauth2/rfc7591/claims.py b/authlib/oauth2/rfc7591/claims.py index b6157b520..381c5b09d 100644 --- a/authlib/oauth2/rfc7591/claims.py +++ b/authlib/oauth2/rfc7591/claims.py @@ -1,30 +1,35 @@ -from authlib.jose import BaseClaims, JsonWebKey -from authlib.jose.errors import InvalidClaimError +from joserfc.errors import InvalidClaimError +from joserfc.errors import JoseError +from joserfc.jwk import KeySet + from authlib.common.urls import is_valid_url +from authlib.oauth2.claims import BaseClaims + +from ..rfc6749 import scope_to_list class ClientMetadataClaims(BaseClaims): # https://tools.ietf.org/html/rfc7591#section-2 REGISTERED_CLAIMS = [ - 'redirect_uris', - 'token_endpoint_auth_method', - 'grant_types', - 'response_types', - 'client_name', - 'client_uri', - 'logo_uri', - 'scope', - 'contacts', - 'tos_uri', - 'policy_uri', - 'jwks_uri', - 'jwks', - 'software_id', - 'software_version', + "redirect_uris", + "token_endpoint_auth_method", + "grant_types", + "response_types", + "client_name", + "client_uri", + "logo_uri", + "scope", + "contacts", + "tos_uri", + "policy_uri", + "jwks_uri", + "jwks", + "software_id", + "software_version", ] - def validate(self): - self._validate_essential_claims() + def validate(self, now=None, leeway=0): + super().validate(now, leeway) self.validate_redirect_uris() self.validate_token_endpoint_auth_method() self.validate_grant_types() @@ -50,31 +55,28 @@ def validate_redirect_uris(self): redirect-based flows MUST implement support for this metadata value. """ - uris = self.get('redirect_uris') + uris = self.get("redirect_uris") if uris: for uri in uris: - self._validate_uri('redirect_uris', uri) + self._validate_uri("redirect_uris", uri) def validate_token_endpoint_auth_method(self): """String indicator of the requested authentication method for the token endpoint. """ # If unspecified or omitted, the default is "client_secret_basic" - if 'token_endpoint_auth_method' not in self: - self['token_endpoint_auth_method'] = 'client_secret_basic' - self._validate_claim_value('token_endpoint_auth_method') + if "token_endpoint_auth_method" not in self: + self["token_endpoint_auth_method"] = "client_secret_basic" def validate_grant_types(self): """Array of OAuth 2.0 grant type strings that the client can use at the token endpoint. """ - self._validate_claim_value('grant_types') def validate_response_types(self): """Array of the OAuth 2.0 response type strings that the client can use at the authorization endpoint. """ - self._validate_claim_value('response_types') def validate_client_name(self): """Human-readable string name of the client to be presented to the @@ -93,7 +95,7 @@ def validate_client_uri(self): page. The value of this field MAY be internationalized, as described in Section 2.2. """ - self._validate_uri('client_uri') + self._validate_uri("client_uri") def validate_logo_uri(self): """URL string that references a logo for the client. If present, the @@ -102,7 +104,7 @@ def validate_logo_uri(self): value of this field MAY be internationalized, as described in Section 2.2. """ - self._validate_uri('logo_uri') + self._validate_uri("logo_uri") def validate_scope(self): """String containing a space-separated list of scope values (as @@ -111,7 +113,6 @@ def validate_scope(self): this list are service specific. If omitted, an authorization server MAY register a client with a default set of scopes. """ - self._validate_claim_value('scope') def validate_contacts(self): """Array of strings representing ways to contact people responsible @@ -120,8 +121,8 @@ def validate_contacts(self): support requests for the client. See Section 6 for information on Privacy Considerations. """ - if 'contacts' in self and not isinstance(self['contacts'], list): - raise InvalidClaimError('contacts') + if "contacts" in self and not isinstance(self["contacts"], list): + raise InvalidClaimError("contacts") def validate_tos_uri(self): """URL string that points to a human-readable terms of service @@ -132,7 +133,7 @@ def validate_tos_uri(self): field MUST point to a valid web page. The value of this field MAY be internationalized, as described in Section 2.2. """ - self._validate_uri('tos_uri') + self._validate_uri("tos_uri") def validate_policy_uri(self): """URL string that points to a human-readable privacy policy document @@ -142,7 +143,7 @@ def validate_policy_uri(self): value of this field MUST point to a valid web page. The value of this field MAY be internationalized, as described in Section 2.2. """ - self._validate_uri('policy_uri') + self._validate_uri("policy_uri") def validate_jwks_uri(self): """URL string referencing the client's JSON Web Key (JWK) Set @@ -158,7 +159,7 @@ def validate_jwks_uri(self): response. """ # TODO: use real HTTP library - self._validate_uri('jwks_uri') + self._validate_uri("jwks_uri") def validate_jwks(self): """Client's JSON Web Key Set [RFC7517] document value, which contains @@ -170,18 +171,16 @@ def validate_jwks(self): public URLs. The "jwks_uri" and "jwks" parameters MUST NOT both be present in the same request or response. """ - if 'jwks' in self: - if 'jwks_uri' in self: + if "jwks" in self: + if "jwks_uri" in self: # The "jwks_uri" and "jwks" parameters MUST NOT both be present - raise InvalidClaimError('jwks') + raise InvalidClaimError("jwks") - jwks = self['jwks'] + jwks = self["jwks"] try: - key_set = JsonWebKey.import_key_set(jwks) - if not key_set: - raise InvalidClaimError('jwks') - except ValueError: - raise InvalidClaimError('jwks') + KeySet.import_key_set(jwks) + except (JoseError, ValueError) as exc: + raise InvalidClaimError("jwks") from exc def validate_software_id(self): """A unique identifier string (e.g., a Universally Unique Identifier @@ -214,5 +213,59 @@ def validate_software_version(self): def _validate_uri(self, key, uri=None): if uri is None: uri = self.get(key) - if uri and not is_valid_url(uri): + if uri and not is_valid_url(uri, fragments_allowed=False): raise InvalidClaimError(key) + + @classmethod + def get_claims_options(cls, metadata): + """Generate claims options validation from Authorization Server metadata.""" + scopes_supported = metadata.get("scopes_supported") + response_types_supported = metadata.get("response_types_supported") + grant_types_supported = metadata.get("grant_types_supported") + auth_methods_supported = metadata.get("token_endpoint_auth_methods_supported") + options = {} + if scopes_supported is not None: + scopes_supported = set(scopes_supported) + + def _validate_scope(claims, value): + if not value: + return True + + scopes = set(scope_to_list(value)) + return scopes_supported.issuperset(scopes) + + options["scope"] = {"validate": _validate_scope} + + if response_types_supported is not None: + response_types_supported = [ + set(items.split()) for items in response_types_supported + ] + + def _validate_response_types(claims, value): + # If omitted, the default is that the client will use only the "code" + # response type. + response_types = ( + [set(items.split()) for items in value] if value else [{"code"}] + ) + return all( + response_type in response_types_supported + for response_type in response_types + ) + + options["response_types"] = {"validate": _validate_response_types} + + if grant_types_supported is not None: + grant_types_supported = set(grant_types_supported) + + def _validate_grant_types(claims, value): + # If omitted, the default behavior is that the client will use only + # the "authorization_code" Grant Type. + grant_types = set(value) if value else {"authorization_code"} + return grant_types_supported.issuperset(grant_types) + + options["grant_types"] = {"validate": _validate_grant_types} + + if auth_methods_supported is not None: + options["token_endpoint_auth_method"] = {"values": auth_methods_supported} + + return options diff --git a/authlib/oauth2/rfc7591/endpoint.py b/authlib/oauth2/rfc7591/endpoint.py index eff588cea..f7b69408a 100644 --- a/authlib/oauth2/rfc7591/endpoint.py +++ b/authlib/oauth2/rfc7591/endpoint.py @@ -1,34 +1,37 @@ +import binascii import os import time -import binascii -from authlib.consts import default_json_headers + +from joserfc import jwt +from joserfc.errors import JoseError + +from authlib._joserfc_helpers import import_any_key from authlib.common.security import generate_token -from authlib.jose import JsonWebToken, JoseError -from ..rfc6749 import AccessDeniedError, InvalidRequestError -from ..rfc6749.util import scope_to_list +from authlib.consts import default_json_headers +from authlib.deprecate import deprecate + +from ..rfc6749 import AccessDeniedError +from ..rfc6749 import InvalidRequestError from .claims import ClientMetadataClaims -from .errors import ( - InvalidClientMetadataError, - UnapprovedSoftwareStatementError, - InvalidSoftwareStatementError, -) +from .errors import InvalidClientMetadataError +from .errors import InvalidSoftwareStatementError +from .errors import UnapprovedSoftwareStatementError -class ClientRegistrationEndpoint(object): +class ClientRegistrationEndpoint: """The client registration endpoint is an OAuth 2.0 endpoint designed to allow a client to be registered with the authorization server. """ - ENDPOINT_NAME = 'client_registration' - #: The claims validation class - claims_class = ClientMetadataClaims + ENDPOINT_NAME = "client_registration" #: Rewrite this value with a list to support ``software_statement`` #: e.g. ``software_statement_alg_values_supported = ['RS256']`` software_statement_alg_values_supported = None - def __init__(self, server): + def __init__(self, server=None, claims_classes=None): self.server = server + self.claims_classes = claims_classes or [ClientMetadataClaims] def __call__(self, request): return self.create_registration_response(request) @@ -41,7 +44,7 @@ def create_registration_response(self, request): request.credential = token client_metadata = self.extract_client_metadata(request) - client_info = self.generate_client_info() + client_info = self.generate_client_info(request) body = {} body.update(client_metadata) body.update(client_info) @@ -52,22 +55,31 @@ def create_registration_response(self, request): return 201, body, default_json_headers def extract_client_metadata(self, request): - if not request.data: + if not request.payload.data: raise InvalidRequestError() - json_data = request.data.copy() - software_statement = json_data.pop('software_statement', None) + json_data = request.payload.data.copy() + software_statement = json_data.pop("software_statement", None) if software_statement and self.software_statement_alg_values_supported: data = self.extract_software_statement(software_statement, request) json_data.update(data) - options = self.get_claims_options() - claims = self.claims_class(json_data, {}, options, self.server.metadata) - try: - claims.validate() - except JoseError as error: - raise InvalidClientMetadataError(error.description) - return claims.get_registered_claims() + client_metadata = {} + server_metadata = self.get_server_metadata() + for claims_class in self.claims_classes: + options = ( + claims_class.get_claims_options(server_metadata) + if hasattr(claims_class, "get_claims_options") and server_metadata + else {} + ) + claims = claims_class(json_data, {}, options, server_metadata) + try: + claims.validate() + except JoseError as error: + raise InvalidClientMetadataError(error.description) from error + + client_metadata.update(**claims.get_registered_claims()) + return client_metadata def extract_software_statement(self, software_statement, request): key = self.resolve_public_key(request) @@ -75,60 +87,36 @@ def extract_software_statement(self, software_statement, request): raise UnapprovedSoftwareStatementError() try: - jwt = JsonWebToken(self.software_statement_alg_values_supported) - claims = jwt.decode(software_statement, key) + key = import_any_key(key) + algorithms = self.software_statement_alg_values_supported + token = jwt.decode(software_statement, key, algorithms=algorithms) # there is no need to validate claims - return claims - except JoseError: - raise InvalidSoftwareStatementError() - - def get_claims_options(self): - """Generate claims options validation from Authorization Server metadata.""" - metadata = self.server.metadata - if not metadata: - return {} - - scopes_supported = metadata.get('scopes_supported') - response_types_supported = metadata.get('response_types_supported') - grant_types_supported = metadata.get('grant_types_supported') - auth_methods_supported = metadata.get('token_endpoint_auth_methods_supported') - options = {} - if scopes_supported is not None: - scopes_supported = set(scopes_supported) - - def _validate_scope(claims, value): - if not value: - return True - scopes = set(scope_to_list(value)) - return scopes_supported.issuperset(scopes) - - options['scope'] = {'validate': _validate_scope} - - if response_types_supported is not None: - response_types_supported = set(response_types_supported) + return token.claims + except JoseError as exc: + raise InvalidSoftwareStatementError() from exc - def _validate_response_types(claims, value): - return response_types_supported.issuperset(set(value)) - - options['response_types'] = {'validate': _validate_response_types} - - if grant_types_supported is not None: - grant_types_supported = set(grant_types_supported) - - def _validate_grant_types(claims, value): - return grant_types_supported.issuperset(set(value)) - - options['grant_types'] = {'validate': _validate_grant_types} - - if auth_methods_supported is not None: - options['token_endpoint_auth_method'] = {'values': auth_methods_supported} + def generate_client_info(self, request): + # https://tools.ietf.org/html/rfc7591#section-3.2.1 + try: + client_id = self.generate_client_id(request) + except TypeError: # pragma: no cover + client_id = self.generate_client_id() # type: ignore + deprecate( + "generate_client_id takes a 'request' parameter. " + "It will become mandatory in coming releases", + version="1.8", + ) - return options + try: + client_secret = self.generate_client_secret(request) + except TypeError: # pragma: no cover + client_secret = self.generate_client_secret() + deprecate( + "generate_client_secret takes a 'request' parameter. " + "It will become mandatory in coming releases", + version="1.8", + ) - def generate_client_info(self): - # https://tools.ietf.org/html/rfc7591#section-3.2.1 - client_id = self.generate_client_id() - client_secret = self.generate_client_secret() client_id_issued_at = int(time.time()) client_secret_expires_at = 0 return dict( @@ -141,30 +129,37 @@ def generate_client_info(self): def generate_client_registration_info(self, client, request): """Generate ```registration_client_uri`` and ``registration_access_token`` for RFC7592. This method returns ``None`` by default. Developers MAY rewrite - this method to return registration information.""" + this method to return registration information. + """ return None def create_endpoint_request(self, request): return self.server.create_json_request(request) - def generate_client_id(self): + def generate_client_id(self, request): """Generate ``client_id`` value. Developers MAY rewrite this method to use their own way to generate ``client_id``. """ return generate_token(42) - def generate_client_secret(self): + def generate_client_secret(self, request): """Generate ``client_secret`` value. Developers MAY rewrite this method to use their own way to generate ``client_secret``. """ - return binascii.hexlify(os.urandom(24)).decode('ascii') + return binascii.hexlify(os.urandom(24)).decode("ascii") + + def get_server_metadata(self): + """Return server metadata which includes supported grant types, + response types and etc. + """ + raise NotImplementedError() def authenticate_token(self, request): """Authenticate current credential who is requesting to register a client. Developers MUST implement this method in subclass:: def authenticate_token(self, request): - auth = request.headers.get('Authorization') + auth = request.headers.get("Authorization") return get_token_by_auth(auth) :return: token instance diff --git a/authlib/oauth2/rfc7591/errors.py b/authlib/oauth2/rfc7591/errors.py index 31693c047..4b6ed5b53 100644 --- a/authlib/oauth2/rfc7591/errors.py +++ b/authlib/oauth2/rfc7591/errors.py @@ -3,9 +3,10 @@ class InvalidRedirectURIError(OAuth2Error): """The value of one or more redirection URIs is invalid. - https://tools.ietf.org/html/rfc7591#section-3.2.2 + https://tools.ietf.org/html/rfc7591#section-3.2.2. """ - error = 'invalid_redirect_uri' + + error = "invalid_redirect_uri" class InvalidClientMetadataError(OAuth2Error): @@ -13,21 +14,24 @@ class InvalidClientMetadataError(OAuth2Error): server has rejected this request. Note that an authorization server MAY choose to substitute a valid value for any requested parameter of a client's metadata. - https://tools.ietf.org/html/rfc7591#section-3.2.2 + https://tools.ietf.org/html/rfc7591#section-3.2.2. """ - error = 'invalid_client_metadata' + + error = "invalid_client_metadata" class InvalidSoftwareStatementError(OAuth2Error): """The software statement presented is invalid. - https://tools.ietf.org/html/rfc7591#section-3.2.2 + https://tools.ietf.org/html/rfc7591#section-3.2.2. """ - error = 'invalid_software_statement' + + error = "invalid_software_statement" class UnapprovedSoftwareStatementError(OAuth2Error): """The software statement presented is not approved for use by this authorization server. - https://tools.ietf.org/html/rfc7591#section-3.2.2 + https://tools.ietf.org/html/rfc7591#section-3.2.2. """ - error = 'unapproved_software_statement' + + error = "unapproved_software_statement" diff --git a/authlib/oauth2/rfc7592/__init__.py b/authlib/oauth2/rfc7592/__init__.py index 6a6457be5..a5b3cb1c4 100644 --- a/authlib/oauth2/rfc7592/__init__.py +++ b/authlib/oauth2/rfc7592/__init__.py @@ -1,13 +1,12 @@ -""" - authlib.oauth2.rfc7592 - ~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth2.rfc7592. +~~~~~~~~~~~~~~~~~~~~~~ - This module represents a direct implementation of - OAuth 2.0 Dynamic Client Registration Management Protocol. +This module represents a direct implementation of +OAuth 2.0 Dynamic Client Registration Management Protocol. - https://tools.ietf.org/html/rfc7592 +https://tools.ietf.org/html/rfc7592 """ from .endpoint import ClientConfigurationEndpoint -__all__ = ['ClientConfigurationEndpoint'] +__all__ = ["ClientConfigurationEndpoint"] diff --git a/authlib/oauth2/rfc7592/endpoint.py b/authlib/oauth2/rfc7592/endpoint.py index 5a036d714..ee9bf88a8 100644 --- a/authlib/oauth2/rfc7592/endpoint.py +++ b/authlib/oauth2/rfc7592/endpoint.py @@ -1,21 +1,31 @@ +from joserfc.errors import JoseError + from authlib.consts import default_json_headers + from ..rfc6749 import AccessDeniedError -from ..rfc6750 import InvalidTokenError +from ..rfc6749 import InvalidClientError +from ..rfc6749 import InvalidRequestError +from ..rfc6749 import UnauthorizedClientError +from ..rfc7591 import InvalidClientMetadataError +from ..rfc7591.claims import ClientMetadataClaims -class ClientConfigurationEndpoint(object): - ENDPOINT_NAME = 'client_configuration' +class ClientConfigurationEndpoint: + ENDPOINT_NAME = "client_configuration" - def __init__(self, server): + def __init__(self, server=None, claims_classes=None): self.server = server + self.claims_classes = claims_classes or [ClientMetadataClaims] def __call__(self, request): return self.create_configuration_response(request) def create_configuration_response(self, request): + # This request is authenticated by the registration access token issued + # to the client. token = self.authenticate_token(request) if not token: - raise InvalidTokenError() + raise AccessDeniedError() request.credential = token @@ -24,21 +34,26 @@ def create_configuration_response(self, request): # If the client does not exist on this server, the server MUST respond # with HTTP 401 Unauthorized and the registration access token used to # make this request SHOULD be immediately revoked. - self.revoke_access_token(request) - raise InvalidTokenError() + self.revoke_access_token(request, token) + raise InvalidClientError( + status_code=401, description="The client does not exist on this server." + ) if not self.check_permission(client, request): # If the client does not have permission to read its record, the server # MUST return an HTTP 403 Forbidden. - raise AccessDeniedError() + raise UnauthorizedClientError( + status_code=403, + description="The client does not have permission to read its record.", + ) request.client = client - if request.method == 'GET': + if request.method == "GET": return self.create_read_client_response(client, request) - elif request.method == 'DELETE': + elif request.method == "DELETE": return self.create_delete_client_response(client, request) - elif request.method == 'PUT': + elif request.method == "PUT": return self.create_update_client_response(client, request) def create_endpoint_request(self, request): @@ -46,90 +61,88 @@ def create_endpoint_request(self, request): def create_read_client_response(self, client, request): body = self.introspect_client(client) - info = self.generate_client_registration_info(client, request) - body.update(info) + body.update(self.generate_client_registration_info(client, request)) return 200, body, default_json_headers def create_delete_client_response(self, client, request): - """To deprive itself on the authorization server, the client makes - an HTTP DELETE request to the client configuration endpoint. This - request is authenticated by the registration access token issued to - the client. - - The following is a non-normative example request:: - - DELETE /register/s6BhdRkqt3 HTTP/1.1 - Host: server.example.com - Authorization: Bearer reg-23410913-abewfq.123483 - """ self.delete_client(client, request) headers = [ - ('Cache-Control', 'no-store'), - ('Pragma', 'no-cache'), + ("Cache-Control", "no-store"), + ("Pragma", "no-cache"), ] - return 204, '', headers + return 204, "", headers def create_update_client_response(self, client, request): - """ To update a previously registered client's registration with an - authorization server, the client makes an HTTP PUT request to the - client configuration endpoint with a content type of "application/ - json". - - The following is a non-normative example request:: - - PUT /register/s6BhdRkqt3 HTTP/1.1 - Accept: application/json - Host: server.example.com - Authorization: Bearer reg-23410913-abewfq.123483 - - { - "client_id": "s6BhdRkqt3", - "client_secret": "cf136dc3c1fc93f31185e5885805d", - "redirect_uris": [ - "https://client.example.org/callback", - "https://client.example.org/alt" - ], - "grant_types": ["authorization_code", "refresh_token"], - "token_endpoint_auth_method": "client_secret_basic", - "jwks_uri": "https://client.example.org/my_public_keys.jwks", - "client_name": "My New Example", - "client_name#fr": "Mon Nouvel Exemple", - "logo_uri": "https://client.example.org/newlogo.png", - "logo_uri#fr": "https://client.example.org/fr/newlogo.png" - } - """ # The updated client metadata fields request MUST NOT include the - # "registration_access_token", "registration_client_uri", - # "client_secret_expires_at", or "client_id_issued_at" fields + # 'registration_access_token', 'registration_client_uri', + # 'client_secret_expires_at', or 'client_id_issued_at' fields must_not_include = ( - 'registration_access_token', 'registration_client_uri', - 'client_secret_expires_at', 'client_id_issued_at', + "registration_access_token", + "registration_client_uri", + "client_secret_expires_at", + "client_id_issued_at", ) for k in must_not_include: - if k in request.data: - return + if k in request.payload.data: + raise InvalidRequestError() - # The client MUST include its "client_id" field in the request - client_id = request.data.get('client_id') + # The client MUST include its 'client_id' field in the request + client_id = request.payload.data.get("client_id") if not client_id: - raise + raise InvalidRequestError() if client_id != client.get_client_id(): - raise + raise InvalidRequestError() - # If the client includes the "client_secret" field in the request, + # If the client includes the 'client_secret' field in the request, # the value of this field MUST match the currently issued client # secret for that client. - if 'client_secret' in request.data: - if not client.check_client_secret(request.data['client_secret']): - raise + if "client_secret" in request.payload.data: + if not client.check_client_secret(request.payload.data["client_secret"]): + raise InvalidRequestError() - client = self.save_client(client, request) + client_metadata = self.extract_client_metadata(request) + client = self.update_client(client, client_metadata, request) return self.create_read_client_response(client, request) + def extract_client_metadata(self, request): + json_data = request.payload.data.copy() + client_metadata = {} + server_metadata = self.get_server_metadata() + for claims_class in self.claims_classes: + options = ( + claims_class.get_claims_options(server_metadata) + if hasattr(claims_class, "get_claims_options") and server_metadata + else {} + ) + claims = claims_class(json_data, {}, options, server_metadata) + try: + claims.validate() + except JoseError as error: + print(error) + raise InvalidClientMetadataError(error.description) from error + + client_metadata.update(**claims.get_registered_claims()) + return client_metadata + + def introspect_client(self, client): + return {**client.client_info, **client.client_metadata} + def generate_client_registration_info(self, client, request): """Generate ```registration_client_uri`` and ``registration_access_token`` - for RFC7592. This method returns ``None`` by default. Developers MAY rewrite - this method to return registration information.""" + for RFC7592. By default this method returns the values sent in the current + request. Developers MUST rewrite this method to return different registration + information.:: + + def generate_client_registration_info(self, client, request):{ + access_token = request.headers['Authorization'].split(' ')[1] + return { + 'registration_client_uri': request.uri, + 'registration_access_token': access_token, + } + + :param client: the instance of OAuth client + :param request: formatted request instance + """ raise NotImplementedError() def authenticate_token(self, request): @@ -137,7 +150,7 @@ def authenticate_token(self, request): Developers MUST implement this method in subclass:: def authenticate_token(self, request): - auth = request.headers.get('Authorization') + auth = request.headers.get("Authorization") return get_token_by_auth(auth) :return: token instance @@ -145,15 +158,37 @@ def authenticate_token(self, request): raise NotImplementedError() def authenticate_client(self, request): + """Read a client from the request payload. + Developers MUST implement this method in subclass:: + + def authenticate_client(self, request): + client_id = request.payload.data.get("client_id") + return Client.get(client_id=client_id) + + :return: client instance + """ raise NotImplementedError() - def revoke_access_token(self, request): + def revoke_access_token(self, token, request): + """Revoke a token access in case an invalid client has been requested. + Developers MUST implement this method in subclass:: + + def revoke_access_token(self, token, request): + token.revoked = True + token.save() + + """ raise NotImplementedError() def check_permission(self, client, request): - raise NotImplementedError() + """Checks whether the current client is allowed to be accessed, edited + or deleted. Developers MUST implement it in subclass, e.g.:: - def introspect_client(self, client): + def check_permission(self, client, request): + return client.editable + + :return: boolean + """ raise NotImplementedError() def delete_client(self, client, request): @@ -168,5 +203,26 @@ def delete_client(self, client, request): """ raise NotImplementedError() - def save_client(self, client, request): + def update_client(self, client, client_metadata, request): + """Update the client in the database. Developers MUST implement this method + in subclass:: + + def update_client(self, client, client_metadata, request): + client.set_client_metadata( + {**client.client_metadata, **client_metadata} + ) + client.save() + return client + + :param client: the instance of OAuth client + :param client_metadata: a dict of the client claims to update + :param request: formatted request instance + :return: client instance + """ + raise NotImplementedError() + + def get_server_metadata(self): + """Return server metadata which includes supported grant types, + response types and etc. + """ raise NotImplementedError() diff --git a/authlib/oauth2/rfc7636/__init__.py b/authlib/oauth2/rfc7636/__init__.py index d943f3e1a..25399a586 100644 --- a/authlib/oauth2/rfc7636/__init__.py +++ b/authlib/oauth2/rfc7636/__init__.py @@ -1,14 +1,13 @@ -# -*- coding: utf-8 -*- -""" - authlib.oauth2.rfc7636 - ~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth2.rfc7636. +~~~~~~~~~~~~~~~~~~~~~~ - This module represents a direct implementation of - Proof Key for Code Exchange by OAuth Public Clients. +This module represents a direct implementation of +Proof Key for Code Exchange by OAuth Public Clients. - https://tools.ietf.org/html/rfc7636 +https://tools.ietf.org/html/rfc7636 """ -from .challenge import CodeChallenge, create_s256_code_challenge +from .challenge import CodeChallenge +from .challenge import create_s256_code_challenge -__all__ = ['CodeChallenge', 'create_s256_code_challenge'] +__all__ = ["CodeChallenge", "create_s256_code_challenge"] diff --git a/authlib/oauth2/rfc7636/challenge.py b/authlib/oauth2/rfc7636/challenge.py index 26159b8f1..b413fa7d1 100644 --- a/authlib/oauth2/rfc7636/challenge.py +++ b/authlib/oauth2/rfc7636/challenge.py @@ -1,15 +1,21 @@ -import re import hashlib -from authlib.common.encoding import to_bytes, to_unicode, urlsafe_b64encode -from ..rfc6749.errors import InvalidRequestError, InvalidGrantError +import re + +from authlib.common.encoding import to_bytes +from authlib.common.encoding import to_unicode +from authlib.common.encoding import urlsafe_b64encode +from ..rfc6749 import InvalidGrantError +from ..rfc6749 import InvalidRequestError +from ..rfc6749 import OAuth2Request -CODE_VERIFIER_PATTERN = re.compile(r'^[a-zA-Z0-9\-._~]{43,128}$') +CODE_VERIFIER_PATTERN = re.compile(r"^[a-zA-Z0-9\-._~]{43,128}$") +CODE_CHALLENGE_PATTERN = re.compile(r"^[a-zA-Z0-9\-._~]{43,128}$") def create_s256_code_challenge(code_verifier): """Create S256 code_challenge with the given code_verifier.""" - data = hashlib.sha256(to_bytes(code_verifier, 'ascii')).digest() + data = hashlib.sha256(to_bytes(code_verifier, "ascii")).digest() return to_unicode(urlsafe_b64encode(data)) @@ -24,7 +30,7 @@ def compare_s256_code_challenge(code_verifier, code_challenge): return create_s256_code_challenge(code_verifier) == code_challenge -class CodeChallenge(object): +class CodeChallenge: """CodeChallenge extension to Authorization Code Grant. It is used to improve the security of Authorization Code flow for public clients by sending extra "code_challenge" and "code_verifier" to the authorization @@ -34,19 +40,17 @@ class CodeChallenge(object): ``code_challenge_method`` into database when ``save_authorization_code``. Then register this extension via:: - server.register_grant( - AuthorizationCodeGrant, - [CodeChallenge(required=True)] - ) + server.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=True)]) """ + #: defaults to "plain" if not present in the request - DEFAULT_CODE_CHALLENGE_METHOD = 'plain' + DEFAULT_CODE_CHALLENGE_METHOD = "plain" #: supported ``code_challenge_method`` - SUPPORTED_CODE_CHALLENGE_METHOD = ['plain', 'S256'] + SUPPORTED_CODE_CHALLENGE_METHOD = ["plain", "S256"] CODE_CHALLENGE_METHODS = { - 'plain': compare_plain_code_challenge, - 'S256': compare_s256_code_challenge, + "plain": compare_plain_code_challenge, + "S256": compare_s256_code_challenge, } def __init__(self, required=True): @@ -54,48 +58,66 @@ def __init__(self, required=True): def __call__(self, grant): grant.register_hook( - 'after_validate_authorization_request', + "after_validate_authorization_request_payload", self.validate_code_challenge, ) grant.register_hook( - 'after_validate_token_request', + "after_validate_token_request", self.validate_code_verifier, ) - def validate_code_challenge(self, grant): - request = grant.request - challenge = request.args.get('code_challenge') - method = request.args.get('code_challenge_method') + def validate_code_challenge(self, grant, redirect_uri): + request: OAuth2Request = grant.request + challenge = request.payload.data.get("code_challenge") + method = request.payload.data.get("code_challenge_method") if not challenge and not method: return if not challenge: - raise InvalidRequestError('Missing "code_challenge"') + raise InvalidRequestError("Missing 'code_challenge'") + + if len(request.payload.datalist.get("code_challenge", [])) > 1: + raise InvalidRequestError("Multiple 'code_challenge' in request.") + + if not CODE_CHALLENGE_PATTERN.match(challenge): + raise InvalidRequestError("Invalid 'code_challenge'") if method and method not in self.SUPPORTED_CODE_CHALLENGE_METHOD: - raise InvalidRequestError('Unsupported "code_challenge_method"') + raise InvalidRequestError("Unsupported 'code_challenge_method'") - def validate_code_verifier(self, grant): - request = grant.request - verifier = request.form.get('code_verifier') + if len(request.payload.datalist.get("code_challenge_method", [])) > 1: + raise InvalidRequestError("Multiple 'code_challenge_method' in request.") + + def validate_code_verifier(self, grant, result): + request: OAuth2Request = grant.request + verifier = request.form.get("code_verifier") # public client MUST verify code challenge - if self.required and request.auth_method == 'none' and not verifier: - raise InvalidRequestError('Missing "code_verifier"') + if self.required and request.auth_method == "none" and not verifier: + raise InvalidRequestError("Missing 'code_verifier'") - authorization_code = request.credential + authorization_code = request.authorization_code challenge = self.get_authorization_code_challenge(authorization_code) # ignore, it is the normal RFC6749 authorization_code request if not challenge and not verifier: return + # RFC 9700 Section 4.8: the authorization server MUST ensure that if + # there was no code_challenge in the authorization request, a request + # to the token endpoint containing a code_verifier is rejected. + if not challenge and verifier: + raise InvalidRequestError( + "The authorization request had no 'code_challenge', " + "but a 'code_verifier' was provided." + ) + # challenge exists, code_verifier is required if not verifier: - raise InvalidRequestError('Missing "code_verifier"') + raise InvalidRequestError("Missing 'code_verifier'") if not CODE_VERIFIER_PATTERN.match(verifier): - raise InvalidRequestError('Invalid "code_verifier"') + raise InvalidRequestError("Invalid 'code_verifier'") # 4.6. Server Verifies code_verifier before Returning the Tokens method = self.get_authorization_code_challenge_method(authorization_code) @@ -104,12 +126,12 @@ def validate_code_verifier(self, grant): func = self.CODE_CHALLENGE_METHODS.get(method) if not func: - raise RuntimeError('No verify method for "{}"'.format(method)) + raise RuntimeError(f"No verify method for '{method}'") # If the values are not equal, an error response indicating # "invalid_grant" MUST be returned. if not func(verifier, challenge): - raise InvalidGrantError(description='Code challenge failed.') + raise InvalidGrantError(description="Code challenge failed.") def get_authorization_code_challenge(self, authorization_code): """Get "code_challenge" associated with this authorization code. diff --git a/authlib/oauth2/rfc7662/__init__.py b/authlib/oauth2/rfc7662/__init__.py index 283776188..ada30736a 100644 --- a/authlib/oauth2/rfc7662/__init__.py +++ b/authlib/oauth2/rfc7662/__init__.py @@ -1,15 +1,14 @@ -# -*- coding: utf-8 -*- -""" - authlib.oauth2.rfc7662 - ~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth2.rfc7662. +~~~~~~~~~~~~~~~~~~~~~~ - This module represents a direct implementation of - OAuth 2.0 Token Introspection. +This module represents a direct implementation of +OAuth 2.0 Token Introspection. - https://tools.ietf.org/html/rfc7662 +https://tools.ietf.org/html/rfc7662 """ from .introspection import IntrospectionEndpoint from .models import IntrospectionToken +from .token_validator import IntrospectTokenValidator -__all__ = ['IntrospectionEndpoint', 'IntrospectionToken'] +__all__ = ["IntrospectionEndpoint", "IntrospectionToken", "IntrospectTokenValidator"] diff --git a/authlib/oauth2/rfc7662/introspection.py b/authlib/oauth2/rfc7662/introspection.py index 44c76c851..9ff7ea9e9 100644 --- a/authlib/oauth2/rfc7662/introspection.py +++ b/authlib/oauth2/rfc7662/introspection.py @@ -1,10 +1,8 @@ -import time from authlib.consts import default_json_headers -from ..rfc6749 import ( - TokenEndpoint, - InvalidRequestError, - UnsupportedTokenTypeError, -) + +from ..rfc6749 import InvalidRequestError +from ..rfc6749 import TokenEndpoint +from ..rfc6749 import UnsupportedTokenTypeError class IntrospectionEndpoint(TokenEndpoint): @@ -13,10 +11,11 @@ class IntrospectionEndpoint(TokenEndpoint): .. _RFC7662: https://tools.ietf.org/html/rfc7662 """ + #: Endpoint name to be registered - ENDPOINT_NAME = 'introspection' + ENDPOINT_NAME = "introspection" - def authenticate_endpoint_credential(self, request, client): + def authenticate_token(self, request, client): """The protected resource calls the introspection endpoint using an HTTP ``POST`` request with parameters sent as "application/x-www-form-urlencoded" data. The protected resource sends a @@ -35,16 +34,22 @@ def authenticate_endpoint_credential(self, request, client): **OPTIONAL** A hint about the type of the token submitted for introspection. """ + self.check_params(request, client) + token = self.query_token( + request.form["token"], request.form.get("token_type_hint") + ) + if token and self.check_permission(token, client, request): + return token + + def check_params(self, request, client): params = request.form - if 'token' not in params: + if "token" not in params: raise InvalidRequestError() - token_type = params.get('token_type_hint') - if token_type and token_type not in self.SUPPORTED_TOKEN_TYPES: + hint = params.get("token_type_hint") + if hint and hint not in self.SUPPORTED_TOKEN_TYPES: raise UnsupportedTokenTypeError() - return self.query_token(params['token'], token_type, client) - def create_endpoint_response(self, request): """Validate introspection request and create the response. @@ -55,10 +60,10 @@ def create_endpoint_response(self, request): # then verifies whether the token was issued to the client making # the revocation request - credential = self.authenticate_endpoint_credential(request, client) + token = self.authenticate_token(request, client) # the authorization server invalidates the token - body = self.create_introspection_payload(credential) + body = self.create_introspection_payload(token) return 200, body, default_json_headers def create_introspection_payload(self, token): @@ -67,31 +72,40 @@ def create_introspection_payload(self, token): # token, then the authorization server MUST return an introspection # response with the "active" field set to "false" if not token: - return {'active': False} - expires_at = token.get_expires_at() - if expires_at < time.time() or token.revoked: - return {'active': False} + return {"active": False} + if token.is_expired() or token.is_revoked(): + return {"active": False} payload = self.introspect_token(token) - if 'active' not in payload: - payload['active'] = True + if "active" not in payload: + payload["active"] = True return payload - def query_token(self, token, token_type_hint, client): + def check_permission(self, token, client, request): + """Check if the request has permission to introspect the token. Developers + MUST implement this method:: + + def check_permission(self, token, client, request): + # only allow a special client to introspect the token + return client.client_id == "introspection_client" + + :return: bool + """ + raise NotImplementedError() + + def query_token(self, token_string, token_type_hint): """Get the token from database/storage by the given token string. Developers should implement this method:: - def query_token(self, token, token_type_hint, client): - if token_type_hint == 'access_token': - tok = Token.query_by_access_token(token) - elif token_type_hint == 'refresh_token': - tok = Token.query_by_refresh_token(token) + def query_token(self, token_string, token_type_hint): + if token_type_hint == "access_token": + tok = Token.query_by_access_token(token_string) + elif token_type_hint == "refresh_token": + tok = Token.query_by_refresh_token(token_string) else: - tok = Token.query_by_access_token(token) + tok = Token.query_by_access_token(token_string) if not tok: - tok = Token.query_by_refresh_token(token) - - if check_client_permission(client, tok): - return tok + tok = Token.query_by_refresh_token(token_string) + return tok """ raise NotImplementedError() @@ -100,18 +114,17 @@ def introspect_token(self, token): dictionary following `Section 2.2`_:: def introspect_token(self, token): - active = is_token_active(token) return { - 'active': active, - 'client_id': token.client_id, - 'token_type': token.token_type, - 'username': get_token_username(token), - 'scope': token.get_scope(), - 'sub': get_token_user_sub(token), - 'aud': token.client_id, - 'iss': 'https://server.example.com/', - 'exp': token.expires_at, - 'iat': token.issued_at, + "active": True, + "client_id": token.client_id, + "token_type": token.token_type, + "username": get_token_username(token), + "scope": token.get_scope(), + "sub": get_token_user_sub(token), + "aud": token.client_id, + "iss": "https://server.example.com/", + "exp": token.expires_at, + "iat": token.issued_at, } .. _`Section 2.2`: https://tools.ietf.org/html/rfc7662#section-2.2 diff --git a/authlib/oauth2/rfc7662/models.py b/authlib/oauth2/rfc7662/models.py index 0f4f0c215..e369fa732 100644 --- a/authlib/oauth2/rfc7662/models.py +++ b/authlib/oauth2/rfc7662/models.py @@ -3,10 +3,10 @@ class IntrospectionToken(dict, TokenMixin): def get_client_id(self): - return self.get('client_id') + return self.get("client_id") def get_scope(self): - return self.get('scope') + return self.get("scope") def get_expires_in(self): # this method is only used in refresh token, @@ -14,13 +14,23 @@ def get_expires_in(self): return 0 def get_expires_at(self): - return self.get('exp', 0) + return self.get("exp", 0) def __getattr__(self, key): # https://tools.ietf.org/html/rfc7662#section-2.2 available_keys = { - 'active', 'scope', 'client_id', 'username', 'token_type', - 'exp', 'iat', 'nbf', 'sub', 'aud', 'iss', 'jti' + "active", + "scope", + "client_id", + "username", + "token_type", + "exp", + "iat", + "nbf", + "sub", + "aud", + "iss", + "jti", } try: return object.__getattribute__(self, key) diff --git a/authlib/oauth2/rfc7662/token_validator.py b/authlib/oauth2/rfc7662/token_validator.py new file mode 100644 index 000000000..213be5641 --- /dev/null +++ b/authlib/oauth2/rfc7662/token_validator.py @@ -0,0 +1,34 @@ +from ..rfc6749 import TokenValidator +from ..rfc6750 import InsufficientScopeError +from ..rfc6750 import InvalidTokenError + + +class IntrospectTokenValidator(TokenValidator): + TOKEN_TYPE = "bearer" + + def introspect_token(self, token_string): + """Request introspection token endpoint with the given token string, + authorization server will return token information in JSON format. + Developers MUST implement this method before using it:: + + def introspect_token(self, token_string): + # for example, introspection token endpoint has limited + # internal IPs to access, so there is no need to add + # authentication. + url = "https://example.com/oauth/introspect" + resp = requests.post(url, data={"token": token_string}) + resp.raise_for_status() + return resp.json() + """ + raise NotImplementedError() + + def authenticate_token(self, token_string): + return self.introspect_token(token_string) + + def validate_token(self, token, scopes, request): + if not token or not token["active"]: + raise InvalidTokenError( + realm=self.realm, extra_attributes=self.extra_attributes + ) + if self.scope_insufficient(token.get("scope"), scopes): + raise InsufficientScopeError() diff --git a/authlib/oauth2/rfc8414/__init__.py b/authlib/oauth2/rfc8414/__init__.py index 2cdbfbdc0..fff67209f 100644 --- a/authlib/oauth2/rfc8414/__init__.py +++ b/authlib/oauth2/rfc8414/__init__.py @@ -1,16 +1,13 @@ -# -*- coding: utf-8 -*- -""" - authlib.oauth2.rfc8414 - ~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth2.rfc8414. +~~~~~~~~~~~~~~~~~~~~~~ - This module represents a direct implementation of - OAuth 2.0 Authorization Server Metadata. +This module represents a direct implementation of +OAuth 2.0 Authorization Server Metadata. - https://tools.ietf.org/html/rfc8414 +https://tools.ietf.org/html/rfc8414 """ from .models import AuthorizationServerMetadata from .well_known import get_well_known_url - -__all__ = ['AuthorizationServerMetadata', 'get_well_known_url'] +__all__ = ["AuthorizationServerMetadata", "get_well_known_url"] diff --git a/authlib/oauth2/rfc8414/models.py b/authlib/oauth2/rfc8414/models.py index 3e89a5c98..31d54b465 100644 --- a/authlib/oauth2/rfc8414/models.py +++ b/authlib/oauth2/rfc8414/models.py @@ -1,27 +1,46 @@ -from authlib.common.urls import urlparse, is_valid_url from authlib.common.security import is_secure_transport +from authlib.common.urls import is_valid_url +from authlib.common.urls import urlparse class AuthorizationServerMetadata(dict): """Define Authorization Server Metadata via `Section 2`_ in RFC8414_. + The :meth:`validate` method can compose extension classes via the + ``metadata_classes`` parameter:: + + from authlib.oauth2 import rfc8414, rfc9101 + + metadata = rfc8414.AuthorizationServerMetadata(data) + metadata.validate(metadata_classes=[rfc9101.AuthorizationServerMetadata]) + .. _RFC8414: https://tools.ietf.org/html/rfc8414 .. _`Section 2`: https://tools.ietf.org/html/rfc8414#section-2 """ + REGISTRY_KEYS = [ - 'issuer', 'authorization_endpoint', 'token_endpoint', - 'jwks_uri', 'registration_endpoint', 'scopes_supported', - 'response_types_supported', 'response_modes_supported', - 'grant_types_supported', 'token_endpoint_auth_methods_supported', - 'token_endpoint_auth_signing_alg_values_supported', - 'service_documentation', 'ui_locales_supported', - 'op_policy_uri', 'op_tos_uri', 'revocation_endpoint', - 'revocation_endpoint_auth_methods_supported', - 'revocation_endpoint_auth_signing_alg_values_supported', - 'introspection_endpoint', - 'introspection_endpoint_auth_methods_supported', - 'introspection_endpoint_auth_signing_alg_values_supported', - 'code_challenge_methods_supported', + "issuer", + "authorization_endpoint", + "token_endpoint", + "jwks_uri", + "registration_endpoint", + "scopes_supported", + "response_types_supported", + "response_modes_supported", + "grant_types_supported", + "token_endpoint_auth_methods_supported", + "token_endpoint_auth_signing_alg_values_supported", + "service_documentation", + "ui_locales_supported", + "op_policy_uri", + "op_tos_uri", + "revocation_endpoint", + "revocation_endpoint_auth_methods_supported", + "revocation_endpoint_auth_signing_alg_values_supported", + "introspection_endpoint", + "introspection_endpoint_auth_methods_supported", + "introspection_endpoint_auth_signing_alg_values_supported", + "code_challenge_methods_supported", ] def validate_issuer(self): @@ -29,7 +48,7 @@ def validate_issuer(self): a URL that uses the "https" scheme and has no query or fragment components. """ - issuer = self.get('issuer') + issuer = self.get("issuer") #: 1. REQUIRED if not issuer: @@ -50,15 +69,14 @@ def validate_authorization_endpoint(self): [RFC6749]. This is REQUIRED unless no grant types are supported that use the authorization endpoint. """ - url = self.get('authorization_endpoint') + url = self.get("authorization_endpoint") if url: if not is_secure_transport(url): - raise ValueError( - '"authorization_endpoint" MUST use "https" scheme') + raise ValueError('"authorization_endpoint" MUST use "https" scheme') return grant_types_supported = set(self.grant_types_supported) - authorization_grant_types = {'authorization_code', 'implicit'} + authorization_grant_types = {"authorization_code", "implicit"} if grant_types_supported & authorization_grant_types: raise ValueError('"authorization_endpoint" is required') @@ -66,12 +84,15 @@ def validate_token_endpoint(self): """URL of the authorization server's token endpoint [RFC6749]. This is REQUIRED unless only the implicit grant type is supported. """ - grant_types_supported = self.get('grant_types_supported') - if grant_types_supported and len(grant_types_supported) == 1 and \ - grant_types_supported[0] == 'implicit': + grant_types_supported = self.get("grant_types_supported") + if ( + grant_types_supported + and len(grant_types_supported) == 1 + and grant_types_supported[0] == "implicit" + ): return - url = self.get('token_endpoint') + url = self.get("token_endpoint") if not url: raise ValueError('"token_endpoint" is required') @@ -89,7 +110,7 @@ def validate_jwks_uri(self): parameter value is REQUIRED for all keys in the referenced JWK Set to indicate each key's intended usage. """ - url = self.get('jwks_uri') + url = self.get("jwks_uri") if url and not is_secure_transport(url): raise ValueError('"jwks_uri" MUST use "https" scheme') @@ -97,10 +118,9 @@ def validate_registration_endpoint(self): """OPTIONAL. URL of the authorization server's OAuth 2.0 Dynamic Client Registration endpoint [RFC7591]. """ - url = self.get('registration_endpoint') + url = self.get("registration_endpoint") if url and not is_secure_transport(url): - raise ValueError( - '"registration_endpoint" MUST use "https" scheme') + raise ValueError('"registration_endpoint" MUST use "https" scheme') def validate_scopes_supported(self): """RECOMMENDED. JSON array containing a list of the OAuth 2.0 @@ -108,7 +128,7 @@ def validate_scopes_supported(self): Servers MAY choose not to advertise some supported scope values even when this parameter is used. """ - validate_array_value(self, 'scopes_supported') + validate_array_value(self, "scopes_supported") def validate_response_types_supported(self): """REQUIRED. JSON array containing a list of the OAuth 2.0 @@ -117,7 +137,7 @@ def validate_response_types_supported(self): "response_types" parameter defined by "OAuth 2.0 Dynamic Client Registration Protocol" [RFC7591]. """ - response_types_supported = self.get('response_types_supported') + response_types_supported = self.get("response_types_supported") if not response_types_supported: raise ValueError('"response_types_supported" is required') if not isinstance(response_types_supported, list): @@ -131,7 +151,7 @@ def validate_response_modes_supported(self): "fragment"]". The response mode value "form_post" is also defined in "OAuth 2.0 Form Post Response Mode" [OAuth.Post]. """ - validate_array_value(self, 'response_modes_supported') + validate_array_value(self, "response_modes_supported") def validate_grant_types_supported(self): """OPTIONAL. JSON array containing a list of the OAuth 2.0 grant @@ -141,7 +161,7 @@ def validate_grant_types_supported(self): Protocol" [RFC7591]. If omitted, the default value is "["authorization_code", "implicit"]". """ - validate_array_value(self, 'grant_types_supported') + validate_array_value(self, "grant_types_supported") def validate_token_endpoint_auth_methods_supported(self): """OPTIONAL. JSON array containing a list of client authentication @@ -151,7 +171,7 @@ def validate_token_endpoint_auth_methods_supported(self): default is "client_secret_basic" -- the HTTP Basic Authentication Scheme specified in Section 2.3.1 of OAuth 2.0 [RFC6749]. """ - validate_array_value(self, 'token_endpoint_auth_methods_supported') + validate_array_value(self, "token_endpoint_auth_methods_supported") def validate_token_endpoint_auth_signing_alg_values_supported(self): """OPTIONAL. JSON array containing a list of the JWS signing @@ -166,8 +186,8 @@ def validate_token_endpoint_auth_signing_alg_values_supported(self): """ _validate_alg_values( self, - 'token_endpoint_auth_signing_alg_values_supported', - self.token_endpoint_auth_methods_supported + "token_endpoint_auth_signing_alg_values_supported", + self.token_endpoint_auth_methods_supported, ) def validate_service_documentation(self): @@ -178,7 +198,7 @@ def validate_service_documentation(self): how to register clients needs to be provided in this documentation. """ - value = self.get('service_documentation') + value = self.get("service_documentation") if value and not is_valid_url(value): raise ValueError('"service_documentation" MUST be a URL') @@ -188,7 +208,7 @@ def validate_ui_locales_supported(self): [RFC5646]. If omitted, the set of supported languages and scripts is unspecified. """ - validate_array_value(self, 'ui_locales_supported') + validate_array_value(self, "ui_locales_supported") def validate_op_policy_uri(self): """OPTIONAL. URL that the authorization server provides to the @@ -201,7 +221,7 @@ def validate_op_policy_uri(self): specification is actually referring to a general OAuth 2.0 feature that is not specific to OpenID Connect. """ - value = self.get('op_policy_uri') + value = self.get("op_policy_uri") if value and not is_valid_url(value): raise ValueError('"op_policy_uri" MUST be a URL') @@ -215,14 +235,15 @@ def validate_op_tos_uri(self): specification is actually referring to a general OAuth 2.0 feature that is not specific to OpenID Connect. """ - value = self.get('op_tos_uri') + value = self.get("op_tos_uri") if value and not is_valid_url(value): raise ValueError('"op_tos_uri" MUST be a URL') def validate_revocation_endpoint(self): """OPTIONAL. URL of the authorization server's OAuth 2.0 revocation - endpoint [RFC7009].""" - url = self.get('revocation_endpoint') + endpoint [RFC7009]. + """ + url = self.get("revocation_endpoint") if url and not is_secure_transport(url): raise ValueError('"revocation_endpoint" MUST use "https" scheme') @@ -235,7 +256,7 @@ def validate_revocation_endpoint_auth_methods_supported(self): "client_secret_basic" -- the HTTP Basic Authentication Scheme specified in Section 2.3.1 of OAuth 2.0 [RFC6749]. """ - validate_array_value(self, 'revocation_endpoint_auth_methods_supported') + validate_array_value(self, "revocation_endpoint_auth_methods_supported") def validate_revocation_endpoint_auth_signing_alg_values_supported(self): """OPTIONAL. JSON array containing a list of the JWS signing @@ -250,18 +271,17 @@ def validate_revocation_endpoint_auth_signing_alg_values_supported(self): """ _validate_alg_values( self, - 'revocation_endpoint_auth_signing_alg_values_supported', - self.revocation_endpoint_auth_methods_supported + "revocation_endpoint_auth_signing_alg_values_supported", + self.revocation_endpoint_auth_methods_supported, ) def validate_introspection_endpoint(self): """OPTIONAL. URL of the authorization server's OAuth 2.0 introspection endpoint [RFC7662]. """ - url = self.get('introspection_endpoint') + url = self.get("introspection_endpoint") if url and not is_secure_transport(url): - raise ValueError( - '"introspection_endpoint" MUST use "https" scheme') + raise ValueError('"introspection_endpoint" MUST use "https" scheme') def validate_introspection_endpoint_auth_methods_supported(self): """OPTIONAL. JSON array containing a list of client authentication @@ -274,7 +294,7 @@ def validate_introspection_endpoint_auth_methods_supported(self): omitted, the set of supported authentication methods MUST be determined by other means. """ - validate_array_value(self, 'introspection_endpoint_auth_methods_supported') + validate_array_value(self, "introspection_endpoint_auth_methods_supported") def validate_introspection_endpoint_auth_signing_alg_values_supported(self): """OPTIONAL. JSON array containing a list of the JWS signing @@ -289,8 +309,8 @@ def validate_introspection_endpoint_auth_signing_alg_values_supported(self): """ _validate_alg_values( self, - 'introspection_endpoint_auth_signing_alg_values_supported', - self.introspection_endpoint_auth_methods_supported + "introspection_endpoint_auth_signing_alg_values_supported", + self.introspection_endpoint_auth_methods_supported, ) def validate_code_challenge_methods_supported(self): @@ -303,39 +323,63 @@ def validate_code_challenge_methods_supported(self): [IANA.OAuth.Parameters]. If omitted, the authorization server does not support PKCE. """ - validate_array_value(self, 'code_challenge_methods_supported') + validate_array_value(self, "code_challenge_methods_supported") @property def response_modes_supported(self): #: If omitted, the default is ["query", "fragment"] - return self.get('response_modes_supported', ["query", "fragment"]) + return self.get("response_modes_supported", ["query", "fragment"]) @property def grant_types_supported(self): #: If omitted, the default value is ["authorization_code", "implicit"] - return self.get('grant_types_supported', ["authorization_code", "implicit"]) + return self.get("grant_types_supported", ["authorization_code", "implicit"]) @property def token_endpoint_auth_methods_supported(self): #: If omitted, the default is "client_secret_basic" - return self.get('token_endpoint_auth_methods_supported', ["client_secret_basic"]) + return self.get( + "token_endpoint_auth_methods_supported", ["client_secret_basic"] + ) @property def revocation_endpoint_auth_methods_supported(self): #: If omitted, the default is "client_secret_basic" - return self.get('revocation_endpoint_auth_methods_supported', ["client_secret_basic"]) + return self.get( + "revocation_endpoint_auth_methods_supported", ["client_secret_basic"] + ) @property def introspection_endpoint_auth_methods_supported(self): #: If omitted, the set of supported authentication methods MUST be #: determined by other means #: here, we use "client_secret_basic" - return self.get('introspection_endpoint_auth_methods_supported', ["client_secret_basic"]) + return self.get( + "introspection_endpoint_auth_methods_supported", ["client_secret_basic"] + ) - def validate(self): - """Validate all server metadata value.""" + def validate(self, metadata_classes=None): + """Validate all server metadata values. + + :param metadata_classes: Optional list of metadata extension classes + to validate. Example:: + + from authlib.oauth2 import rfc9101 + from authlib.oidc import discovery + + metadata = discovery.OpenIDProviderMetadata(data) + metadata.validate( + metadata_classes=[rfc9101.AuthorizationServerMetadata] + ) + """ for key in self.REGISTRY_KEYS: - object.__getattribute__(self, 'validate_{}'.format(key))() + object.__getattribute__(self, f"validate_{key}")() + + if metadata_classes: + for cls in metadata_classes: + instance = cls(self) + for key in cls.REGISTRY_KEYS: + object.__getattribute__(instance, f"validate_{key}")() def __getattr__(self, key): try: @@ -349,20 +393,26 @@ def __getattr__(self, key): def _validate_alg_values(data, key, auth_methods_supported): value = data.get(key) if value and not isinstance(value, list): - raise ValueError('"{}" MUST be JSON array'.format(key)) + raise ValueError(f'"{key}" MUST be JSON array') auth_methods = set(auth_methods_supported) - jwt_auth_methods = {'private_key_jwt', 'client_secret_jwt'} + jwt_auth_methods = {"private_key_jwt", "client_secret_jwt"} if auth_methods & jwt_auth_methods: if not value: - raise ValueError('"{}" is required'.format(key)) + raise ValueError(f'"{key}" is required') - if value and 'none' in value: - raise ValueError( - 'the value "none" MUST NOT be used in "{}"'.format(key)) + if value and "none" in value: + raise ValueError(f'the value "none" MUST NOT be used in "{key}"') def validate_array_value(metadata, key): values = metadata.get(key) if values is not None and not isinstance(values, list): - raise ValueError('"{}" MUST be JSON array'.format(key)) + raise ValueError(f'"{key}" MUST be JSON array') + + +def validate_boolean_value(metadata, key): + if key not in metadata: + return + if metadata[key] not in (True, False): + raise ValueError(f'"{key}" MUST be boolean') diff --git a/authlib/oauth2/rfc8414/well_known.py b/authlib/oauth2/rfc8414/well_known.py index dc948d883..db5f0faed 100644 --- a/authlib/oauth2/rfc8414/well_known.py +++ b/authlib/oauth2/rfc8414/well_known.py @@ -1,7 +1,7 @@ from authlib.common.urls import urlparse -def get_well_known_url(issuer, external=False, suffix='oauth-authorization-server'): +def get_well_known_url(issuer, external=False, suffix="oauth-authorization-server"): """Get well-known URI with issuer via `Section 3.1`_. .. _`Section 3.1`: https://tools.ietf.org/html/rfc8414#section-3.1 @@ -13,10 +13,10 @@ def get_well_known_url(issuer, external=False, suffix='oauth-authorization-serve """ parsed = urlparse.urlparse(issuer) path = parsed.path - if path and path != '/': - url_path = '/.well-known/{}{}'.format(suffix, path) + if path and path != "/": + url_path = f"/.well-known/{suffix}{path}" else: - url_path = '/.well-known/{}'.format(suffix) + url_path = f"/.well-known/{suffix}" if not external: return url_path - return parsed.scheme + '://' + parsed.netloc + url_path + return parsed.scheme + "://" + parsed.netloc + url_path diff --git a/authlib/oauth2/rfc8628/__init__.py b/authlib/oauth2/rfc8628/__init__.py index 2d4447f85..1a449c48b 100644 --- a/authlib/oauth2/rfc8628/__init__.py +++ b/authlib/oauth2/rfc8628/__init__.py @@ -1,23 +1,28 @@ -# -*- coding: utf-8 -*- -""" - authlib.oauth2.rfc8628 - ~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth2.rfc8628. +~~~~~~~~~~~~~~~~~~~~~~ - This module represents an implementation of - OAuth 2.0 Device Authorization Grant. +This module represents an implementation of +OAuth 2.0 Device Authorization Grant. - https://tools.ietf.org/html/rfc8628 +https://tools.ietf.org/html/rfc8628 """ +from .device_code import DEVICE_CODE_GRANT_TYPE +from .device_code import DeviceCodeGrant from .endpoint import DeviceAuthorizationEndpoint -from .device_code import DeviceCodeGrant, DEVICE_CODE_GRANT_TYPE -from .models import DeviceCredentialMixin, DeviceCredentialDict -from .errors import AuthorizationPendingError, SlowDownError, ExpiredTokenError - +from .errors import AuthorizationPendingError +from .errors import ExpiredTokenError +from .errors import SlowDownError +from .models import DeviceCredentialDict +from .models import DeviceCredentialMixin __all__ = [ - 'DeviceAuthorizationEndpoint', - 'DeviceCodeGrant', 'DEVICE_CODE_GRANT_TYPE', - 'DeviceCredentialMixin', 'DeviceCredentialDict', - 'AuthorizationPendingError', 'SlowDownError', 'ExpiredTokenError', + "DeviceAuthorizationEndpoint", + "DeviceCodeGrant", + "DEVICE_CODE_GRANT_TYPE", + "DeviceCredentialMixin", + "DeviceCredentialDict", + "AuthorizationPendingError", + "SlowDownError", + "ExpiredTokenError", ] diff --git a/authlib/oauth2/rfc8628/device_code.py b/authlib/oauth2/rfc8628/device_code.py index 67be3365b..a38053bab 100644 --- a/authlib/oauth2/rfc8628/device_code.py +++ b/authlib/oauth2/rfc8628/device_code.py @@ -1,20 +1,17 @@ -import time import logging -from ..rfc6749.errors import ( - InvalidRequestError, - InvalidClientError, - UnauthorizedClientError, - AccessDeniedError, -) -from ..rfc6749 import BaseGrant, TokenEndpointMixin -from .errors import ( - AuthorizationPendingError, - ExpiredTokenError, - SlowDownError, -) + +from ..rfc6749 import BaseGrant +from ..rfc6749 import TokenEndpointMixin +from ..rfc6749.errors import AccessDeniedError +from ..rfc6749.errors import InvalidRequestError +from ..rfc6749.errors import UnauthorizedClientError +from ..rfc6749.hooks import hooked +from .errors import AuthorizationPendingError +from .errors import ExpiredTokenError +from .errors import SlowDownError log = logging.getLogger(__name__) -DEVICE_CODE_GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:device_code' +DEVICE_CODE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:device_code" class DeviceCodeGrant(BaseGrant, TokenEndpointMixin): @@ -61,7 +58,9 @@ class DeviceCodeGrant(BaseGrant, TokenEndpointMixin): granted access, an error if they are denied access, or an indication that the client should continue to poll. """ + GRANT_TYPE = DEVICE_CODE_GRANT_TYPE + TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"] def validate_token_request(self): """After displaying instructions to the user, the client creates an @@ -91,22 +90,21 @@ def validate_token_request(self): &device_code=GmRhmhcxhwAzkoEqiMEg_DnyEysNkuNhszIySk9eS &client_id=1406020730 """ - device_code = self.request.data.get('device_code') + device_code = self.request.payload.data.get("device_code") if not device_code: - raise InvalidRequestError('Missing "device_code" in payload') + raise InvalidRequestError("Missing 'device_code' in payload") - if not self.request.client_id: - raise InvalidRequestError('Missing "client_id" in payload') + client = self.authenticate_token_endpoint_client() + if not client.check_grant_type(self.GRANT_TYPE): + raise UnauthorizedClientError( + f"The client is not authorized to use 'response_type={self.GRANT_TYPE}'", + ) credential = self.query_device_credential(device_code) if not credential: - raise InvalidRequestError('Invalid "device_code" in payload') + raise InvalidRequestError("Invalid 'device_code' in payload") - if credential.get_client_id() != self.request.client_id: - raise UnauthorizedClientError() - - client = self.authenticate_token_endpoint_client() - if not client.check_grant_type(self.GRANT_TYPE): + if credential.get_client_id() != client.get_client_id(): raise UnauthorizedClientError() user = self.validate_device_credential(credential) @@ -114,6 +112,7 @@ def validate_token_request(self): self.request.client = client self.request.credential = credential + @hooked def create_token_response(self): """If the access token request is valid and authorized, the authorization server issues an access token and optional refresh @@ -124,14 +123,16 @@ def create_token_response(self): token = self.generate_token( user=self.request.user, scope=scope, - include_refresh_token=client.check_grant_type('refresh_token'), + include_refresh_token=client.check_grant_type("refresh_token"), ) - log.debug('Issue token %r to %r', token, client) + log.debug("Issue token %r to %r", token, client) self.save_token(token) - self.execute_hook('process_token', token=token) return 200, token, self.TOKEN_RESPONSE_HEADER def validate_device_credential(self, credential): + if credential.is_expired(): + raise ExpiredTokenError() + user_code = credential.get_user_code() user_grant = self.query_user_grant(user_code) @@ -141,31 +142,17 @@ def validate_device_credential(self, credential): raise AccessDeniedError() return user - exp = credential.get_expires_at() - now = time.time() - if exp < now: - raise ExpiredTokenError() - - if self.should_slow_down(credential, now): + if self.should_slow_down(credential): raise SlowDownError() raise AuthorizationPendingError() - def authenticate_token_endpoint_client(self): - client = self.server.query_client(self.request.client_id) - if not client: - raise InvalidClientError() - self.server.send_signal( - 'after_authenticate_client', - client=client, grant=self) - return client - def query_device_credential(self, device_code): """Get device credential from previously savings via ``DeviceAuthorizationEndpoint``. Developers MUST implement it in subclass:: def query_device_credential(self, device_code): - return DeviceCredential.query.get(device_code) + return DeviceCredential.get(device_code) :param device_code: a string represent the code. :return: DeviceCredential instance @@ -178,19 +165,19 @@ def query_user_grant(self, user_code): def query_user_grant(self, user_code): # e.g. we saved user grant info in redis - data = redis.get('oauth_user_grant:' + user_code) + data = redis.get("oauth_user_grant:" + user_code) if not data: return None user_id, allowed = data.split() - user = User.query.get(user_id) + user = User.get(user_id) return user, bool(allowed) Note, user grant information is saved by verification endpoint. """ raise NotImplementedError() - def should_slow_down(self, credential, now): + def should_slow_down(self, credential): """The authorization request is still pending and polling should continue, but the interval MUST be increased by 5 seconds for this and all subsequent requests. diff --git a/authlib/oauth2/rfc8628/endpoint.py b/authlib/oauth2/rfc8628/endpoint.py index fda5f1a3f..555715d45 100644 --- a/authlib/oauth2/rfc8628/endpoint.py +++ b/authlib/oauth2/rfc8628/endpoint.py @@ -1,10 +1,9 @@ -from authlib.consts import default_json_headers from authlib.common.security import generate_token from authlib.common.urls import add_params_to_uri -from ..rfc6749.errors import InvalidRequestError +from authlib.consts import default_json_headers -class DeviceAuthorizationEndpoint(object): +class DeviceAuthorizationEndpoint: """This OAuth 2.0 [RFC6749] protocol extension enables OAuth clients to request user authorization from applications on devices that have limited input capabilities or lack a suitable browser. Such devices @@ -45,10 +44,11 @@ class DeviceAuthorizationEndpoint(object): code and provides the end-user verification URI. """ - ENDPOINT_NAME = 'device_authorization' + ENDPOINT_NAME = "device_authorization" + CLIENT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"] #: customize "user_code" type, string or digital - USER_CODE_TYPE = 'string' + USER_CODE_TYPE = "string" #: The lifetime in seconds of the "device_code" and "user_code" EXPIRES_IN = 1800 @@ -68,29 +68,55 @@ def __call__(self, request): def create_endpoint_request(self, request): return self.server.create_oauth2_request(request) + def authenticate_client(self, request): + """client_id is REQUIRED **if the client is not** authenticating with the + authorization server as described in Section 3.2.1. of [RFC6749]. + + This means the endpoint support "none" authentication method. In this case, + this endpoint's auth methods are: + + - client_secret_basic + - client_secret_post + - none + + Developers change the value of ``CLIENT_AUTH_METHODS`` in subclass. For + instance:: + + class MyDeviceAuthorizationEndpoint(DeviceAuthorizationEndpoint): + # only support ``client_secret_basic`` auth method + CLIENT_AUTH_METHODS = ["client_secret_basic"] + """ + client = self.server.authenticate_client( + request, self.CLIENT_AUTH_METHODS, self.ENDPOINT_NAME + ) + request.client = client + return client + def create_endpoint_response(self, request): # https://tools.ietf.org/html/rfc8628#section-3.1 - if not request.client_id: - raise InvalidRequestError('Missing "client_id" in payload') - self.server.validate_requested_scope(request.scope) + self.authenticate_client(request) + self.server.validate_requested_scope(request.payload.scope) device_code = self.generate_device_code() user_code = self.generate_user_code() verification_uri = self.get_verification_uri() verification_uri_complete = add_params_to_uri( - verification_uri, [('user_code', user_code)]) + verification_uri, [("user_code", user_code)] + ) data = { - 'device_code': device_code, - 'user_code': user_code, - 'verification_uri': verification_uri, - 'verification_uri_complete': verification_uri_complete, - 'expires_in': self.EXPIRES_IN, - 'interval': self.INTERVAL, + "device_code": device_code, + "user_code": user_code, + "verification_uri": verification_uri, + "verification_uri_complete": verification_uri_complete, + "expires_in": self.EXPIRES_IN, + "interval": self.INTERVAL, } - self.save_device_credential(request.client_id, request.scope, data) + self.save_device_credential( + request.payload.client_id, request.payload.scope, data + ) return 200, data, default_json_headers def generate_user_code(self): @@ -99,7 +125,7 @@ def generate_user_code(self): Developers can rewrite this method to create their own ``user_code``. """ # https://tools.ietf.org/html/rfc8628#section-6.1 - if self.USER_CODE_TYPE == 'digital': + if self.USER_CODE_TYPE == "digital": return create_digital_user_code() return create_string_user_code() @@ -115,7 +141,7 @@ def get_verification_uri(self): Developers MUST implement this method in subclass:: def get_verification_uri(self): - return 'https://your-company.com/active' + return "https://your-company.com/active" """ raise NotImplementedError() @@ -124,25 +150,23 @@ def save_device_credential(self, client_id, scope, data): implement this method in subclass:: def save_device_credential(self, client_id, scope, data): - item = DeviceCredential( - client_id=client_id, - scope=scope, - **data - ) + item = DeviceCredential(client_id=client_id, scope=scope, **data) item.save() """ raise NotImplementedError() def create_string_user_code(): - base = 'BCDFGHJKLMNPQRSTVWXZ' - return '-'.join([generate_token(4, base), generate_token(4, base)]) + base = "BCDFGHJKLMNPQRSTVWXZ" + return "-".join([generate_token(4, base), generate_token(4, base)]) def create_digital_user_code(): - base = '0123456789' - return '-'.join([ - generate_token(3, base), - generate_token(3, base), - generate_token(3, base), - ]) + base = "0123456789" + return "-".join( + [ + generate_token(3, base), + generate_token(3, base), + generate_token(3, base), + ] + ) diff --git a/authlib/oauth2/rfc8628/errors.py b/authlib/oauth2/rfc8628/errors.py index 4a63db825..354306dc6 100644 --- a/authlib/oauth2/rfc8628/errors.py +++ b/authlib/oauth2/rfc8628/errors.py @@ -7,7 +7,8 @@ class AuthorizationPendingError(OAuth2Error): """The authorization request is still pending as the end user hasn't yet completed the user-interaction steps (Section 3.3). """ - error = 'authorization_pending' + + error = "authorization_pending" class SlowDownError(OAuth2Error): @@ -15,7 +16,8 @@ class SlowDownError(OAuth2Error): still pending and polling should continue, but the interval MUST be increased by 5 seconds for this and all subsequent requests. """ - error = 'slow_down' + + error = "slow_down" class ExpiredTokenError(OAuth2Error): @@ -24,4 +26,5 @@ class ExpiredTokenError(OAuth2Error): authorization request but SHOULD wait for user interaction before restarting to avoid unnecessary polling. """ - error = 'expired_token' + + error = "expired_token" diff --git a/authlib/oauth2/rfc8628/models.py b/authlib/oauth2/rfc8628/models.py index 4090ed674..1127ad4a3 100644 --- a/authlib/oauth2/rfc8628/models.py +++ b/authlib/oauth2/rfc8628/models.py @@ -1,5 +1,7 @@ +import time -class DeviceCredentialMixin(object): + +class DeviceCredentialMixin: def get_client_id(self): raise NotImplementedError() @@ -9,19 +11,28 @@ def get_scope(self): def get_user_code(self): raise NotImplementedError() - def get_expires_at(self): + def is_expired(self): raise NotImplementedError() class DeviceCredentialDict(dict, DeviceCredentialMixin): def get_client_id(self): - return self['client_id'] + return self["client_id"] def get_scope(self): - return self.get('scope') + return self.get("scope") def get_user_code(self): - return self['user_code'] + return self["user_code"] + + def get_nonce(self): + return self.get("nonce") + + def get_auth_time(self): + return self.get("auth_time") - def get_expires_at(self): - return self.get('expires_at') + def is_expired(self): + expires_at = self.get("expires_at") + if expires_at is not None: + return expires_at < time.time() + return False diff --git a/authlib/oauth2/rfc8693/__init__.py b/authlib/oauth2/rfc8693/__init__.py index 110b3874f..8ea6c5f69 100644 --- a/authlib/oauth2/rfc8693/__init__.py +++ b/authlib/oauth2/rfc8693/__init__.py @@ -1,10 +1,8 @@ -# -*- coding: utf-8 -*- -""" - authlib.oauth2.rfc8693 - ~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oauth2.rfc8693. +~~~~~~~~~~~~~~~~~~~~~~ - This module represents an implementation of - OAuth 2.0 Token Exchange. +This module represents an implementation of +OAuth 2.0 Token Exchange. - https://tools.ietf.org/html/rfc8693 +https://tools.ietf.org/html/rfc8693 """ diff --git a/authlib/oauth2/rfc9068/__init__.py b/authlib/oauth2/rfc9068/__init__.py new file mode 100644 index 000000000..2d1d87d86 --- /dev/null +++ b/authlib/oauth2/rfc9068/__init__.py @@ -0,0 +1,11 @@ +from .introspection import JWTIntrospectionEndpoint +from .revocation import JWTRevocationEndpoint +from .token import JWTBearerTokenGenerator +from .token_validator import JWTBearerTokenValidator + +__all__ = [ + "JWTBearerTokenGenerator", + "JWTBearerTokenValidator", + "JWTIntrospectionEndpoint", + "JWTRevocationEndpoint", +] diff --git a/authlib/oauth2/rfc9068/claims.py b/authlib/oauth2/rfc9068/claims.py new file mode 100644 index 000000000..641a63940 --- /dev/null +++ b/authlib/oauth2/rfc9068/claims.py @@ -0,0 +1,35 @@ +from joserfc.errors import InvalidClaimError +from joserfc.jwt import JWTClaimsRegistry + +from authlib.oauth2.claims import JWTClaims + + +class JWTAccessTokenClaimsValidator(JWTClaimsRegistry): + def validate_auth_time(self, auth_time): + if not isinstance(auth_time, (int, float)): + raise InvalidClaimError("auth_time") + self.check_value("auth_time", auth_time) + + def validate_amr(self, amr): + if not isinstance(amr, list): + raise InvalidClaimError("amr", amr) + + +class JWTAccessTokenClaims(JWTClaims): + registry_cls = JWTAccessTokenClaimsValidator + REGISTERED_CLAIMS = JWTClaims.REGISTERED_CLAIMS + [ + "client_id", + "auth_time", + "acr", + "amr", + "scope", + "groups", + "roles", + "entitlements", + ] + + def validate(self, **kwargs): + typ = self.header.get("typ") + if typ and typ.lower() not in ("at+jwt", "application/at+jwt"): + raise InvalidClaimError("typ") + super().validate(**kwargs) diff --git a/authlib/oauth2/rfc9068/introspection.py b/authlib/oauth2/rfc9068/introspection.py new file mode 100644 index 000000000..85fda35fc --- /dev/null +++ b/authlib/oauth2/rfc9068/introspection.py @@ -0,0 +1,129 @@ +from joserfc.errors import ExpiredTokenError +from joserfc.errors import InvalidClaimError + +from authlib.common.errors import ContinueIteration +from authlib.consts import default_json_headers +from authlib.oauth2.rfc6750.errors import InvalidTokenError + +from ..rfc7662 import IntrospectionEndpoint +from .claims import JWTAccessTokenClaims +from .token_validator import JWTBearerTokenValidator + + +class JWTIntrospectionEndpoint(IntrospectionEndpoint): + r"""JWTIntrospectionEndpoint inherits from :ref:`specs/rfc7662` + :class:`~authlib.oauth2.rfc7662.IntrospectionEndpoint` and implements the machinery + to automatically process the JWT access tokens. + + :param issuer: The issuer identifier for which tokens will be introspected. + + :param \\*\\*kwargs: Other parameters are inherited from + :class:`~authlib.oauth2.rfc7662.introspection.IntrospectionEndpoint`. + + :: + + class MyJWTAccessTokenIntrospectionEndpoint(JWTIntrospectionEndpoint): + def get_jwks(self): ... + + def get_username(self, user_id): ... + + + # endpoint dedicated to JWT access token introspection + authorization_server.register_endpoint( + MyJWTAccessTokenIntrospectionEndpoint( + issuer="https://authorization-server.example.org", + ) + ) + + # another endpoint dedicated to refresh token introspection + authorization_server.register_endpoint(MyRefreshTokenIntrospectionEndpoint) + + """ + + #: Endpoint name to be registered + ENDPOINT_NAME = "introspection" + + def __init__(self, issuer, server=None, *args, **kwargs): + super().__init__(*args, server=server, **kwargs) + self.issuer = issuer + + def create_endpoint_response(self, request): + """""" + # The authorization server first validates the client credentials + client = self.authenticate_endpoint_client(request) + + # then verifies whether the token was issued to the client making + # the revocation request + token = self.authenticate_token(request, client) + + # the authorization server invalidates the token + body = self.create_introspection_payload(token) + return 200, body, default_json_headers + + def authenticate_token(self, request, client): + """""" + self.check_params(request, client) + + # do not attempt to decode refresh_tokens + if request.form.get("token_type_hint") not in ("access_token", None): + raise ContinueIteration() + + validator = JWTBearerTokenValidator(issuer=self.issuer, resource_server=None) + validator.get_jwks = self.get_jwks + try: + token = validator.authenticate_token(request.form["token"]) + + # if the token is not a JWT, fall back to the regular flow + except InvalidTokenError as exc: + raise ContinueIteration() from exc + + if token and self.check_permission(token, client, request): + return token + + def create_introspection_payload(self, token: JWTAccessTokenClaims): + if not token: + return {"active": False} + + try: + token.validate() + except ExpiredTokenError: + return {"active": False} + except InvalidClaimError as exc: + if exc.claim == "iss": + raise ContinueIteration() from exc + raise InvalidTokenError() from exc + + payload = { + "active": True, + "token_type": "Bearer", + "client_id": token["client_id"], + "scope": token["scope"], + "sub": token["sub"], + "aud": token["aud"], + "iss": token["iss"], + "exp": token["exp"], + "iat": token["iat"], + } + + if username := self.get_username(token["sub"]): + payload["username"] = username + + return payload + + def get_jwks(self): + """Return the JWKs that will be used to check the JWT access token signature. + Developers MUST re-implement this method:: + + def get_jwks(self): + return load_jwks("jwks.json") + """ + raise NotImplementedError() + + def get_username(self, user_id: str) -> str: + """Returns an username from a user ID. + Developers MAY re-implement this method:: + + def get_username(self, user_id): + return User.get(id=user_id).username + """ + return None diff --git a/authlib/oauth2/rfc9068/revocation.py b/authlib/oauth2/rfc9068/revocation.py new file mode 100644 index 000000000..62e45c2c1 --- /dev/null +++ b/authlib/oauth2/rfc9068/revocation.py @@ -0,0 +1,74 @@ +from authlib.common.errors import ContinueIteration +from authlib.oauth2.rfc6750.errors import InvalidTokenError +from authlib.oauth2.rfc9068.token_validator import JWTBearerTokenValidator + +from ..rfc6749 import UnsupportedTokenTypeError +from ..rfc7009 import RevocationEndpoint + + +class JWTRevocationEndpoint(RevocationEndpoint): + r"""JWTRevocationEndpoint inherits from `RFC7009`_ + :class:`~authlib.oauth2.rfc7009.RevocationEndpoint`. + + The JWT access tokens cannot be revoked. + If the submitted token is a JWT access token, then revocation returns + a `invalid_token_error`. + + :param issuer: The issuer identifier. + + :param \\*\\*kwargs: Other parameters are inherited from + :class:`~authlib.oauth2.rfc7009.RevocationEndpoint`. + + Plain text access tokens and other kind of tokens such as refresh_tokens + will be ignored by this endpoint and passed to the next revocation endpoint:: + + class MyJWTAccessTokenRevocationEndpoint(JWTRevocationEndpoint): + def get_jwks(self): ... + + + # endpoint dedicated to JWT access token revokation + authorization_server.register_endpoint( + MyJWTAccessTokenRevocationEndpoint( + issuer="https://authorization-server.example.org", + ) + ) + + # another endpoint dedicated to refresh token revokation + authorization_server.register_endpoint(MyRefreshTokenRevocationEndpoint) + + .. _RFC7009: https://tools.ietf.org/html/rfc7009 + """ + + def __init__(self, issuer, server=None, *args, **kwargs): + super().__init__(*args, server=server, **kwargs) + self.issuer = issuer + + def authenticate_token(self, request, client): + """""" + self.check_params(request, client) + + # do not attempt to revoke refresh_tokens + if request.form.get("token_type_hint") not in ("access_token", None): + raise ContinueIteration() + + validator = JWTBearerTokenValidator(issuer=self.issuer, resource_server=None) + validator.get_jwks = self.get_jwks + + try: + validator.authenticate_token(request.form["token"]) + + # if the token is not a JWT, fall back to the regular flow + except InvalidTokenError as exc: + raise ContinueIteration() from exc + + # JWT access token cannot be revoked + raise UnsupportedTokenTypeError() + + def get_jwks(self): + """Return the JWKs that will be used to check the JWT access token signature. + Developers MUST re-implement this method:: + + def get_jwks(self): + return load_jwks("jwks.json") + """ + raise NotImplementedError() diff --git a/authlib/oauth2/rfc9068/token.py b/authlib/oauth2/rfc9068/token.py new file mode 100644 index 000000000..97959a1b0 --- /dev/null +++ b/authlib/oauth2/rfc9068/token.py @@ -0,0 +1,217 @@ +import time + +from joserfc import jwt + +from authlib._joserfc_helpers import import_any_key +from authlib.common.security import generate_token +from authlib.oauth2.rfc6750.token import BearerTokenGenerator + + +class JWTBearerTokenGenerator(BearerTokenGenerator): + r"""A JWT formatted access token generator. + + :param issuer: The issuer identifier. Will appear in the JWT ``iss`` claim. + + :param \\*\\*kwargs: Other parameters are inherited from + :class:`~authlib.oauth2.rfc6750.token.BearerTokenGenerator`. + + This token generator can be registered into the authorization server:: + + class MyJWTBearerTokenGenerator(JWTBearerTokenGenerator): + def get_jwks(self): ... + + def get_extra_claims(self, client, grant_type, user, scope): ... + + + authorization_server.register_token_generator( + "default", + MyJWTBearerTokenGenerator( + issuer="https://authorization-server.example.org" + ), + ) + """ + + def __init__( + self, + issuer, + alg="RS256", + refresh_token_generator=None, + expires_generator=None, + ): + super().__init__( + self.access_token_generator, refresh_token_generator, expires_generator + ) + self.issuer = issuer + self.alg = alg + + def get_jwks(self): + """Return the JWKs that will be used to sign the JWT access token. + Developers MUST re-implement this method:: + + def get_jwks(self): + return load_jwks("jwks.json") + """ + raise NotImplementedError() + + def get_extra_claims(self, client, grant_type, user, scope): + """Return extra claims to add in the JWT access token. Developers MAY + re-implement this method to add identity claims like the ones in + :ref:`specs/oidc` ID Token, or any other arbitrary claims:: + + def get_extra_claims(self, client, grant_type, user, scope): + return generate_user_info(user, scope) + """ + return {} + + def get_audiences(self, client, user, scope) -> str | list[str]: + """Return the audience for the token. By default this simply returns + the client ID. Developers MAY re-implement this method to add extra + audiences:: + + def get_audiences(self, client, user, scope): + return [ + client.get_client_id(), + resource_server.get_id(), + ] + """ + return client.get_client_id() + + def get_acr(self, user) -> str | None: + """Authentication Context Class Reference. + Returns a user-defined case sensitive string indicating the class of + authentication the used performed. Token audience may refuse to give access to + some resources if some ACR criteria are not met. + :ref:`specs/oidc` defines one special value: ``0`` means that the user + authentication did not respect `ISO29115`_ level 1, and will be refused monetary + operations. Developers MAY re-implement this method:: + + def get_acr(self, user): + if user.insecure_session(): + return "0" + return "urn:mace:incommon:iap:silver" + + .. _ISO29115: https://www.iso.org/standard/45138.html + """ + return None + + def get_auth_time(self, user) -> int | None: + """User authentication time. + Time when the End-User authentication occurred. Its value is a JSON number + representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC + until the date/time. Developers MAY re-implement this method:: + + def get_auth_time(self, user): + return datetime.timestamp(user.get_auth_time()) + """ + return None + + def get_amr(self, user) -> list[str] | None: + """Authentication Methods References. + Defined by :ref:`specs/oidc` as an option list of user-defined case-sensitive + strings indication which authentication methods have been used to authenticate + the user. Developers MAY re-implement this method:: + + def get_amr(self, user): + return ["2FA"] if user.has_2fa_enabled() else [] + """ + return None + + def get_jti(self, client, grant_type, user, scope) -> str: + """JWT ID. + Create an unique identifier for the token. Developers MAY re-implement + this method:: + + def get_jti(self, client, grant_type, user scope): + return generate_random_string(16) + """ + return generate_token(16) + + def access_token_generator(self, client, grant_type, user, scope): + now = int(time.time()) + expires_in = now + self._get_expires_in(client, grant_type) + + token_data = { + "iss": self.issuer, + "exp": expires_in, + "client_id": client.get_client_id(), + "iat": now, + "jti": self.get_jti(client, grant_type, user, scope), + "scope": scope, + } + + # In cases of access tokens obtained through grants where a resource owner is + # involved, such as the authorization code grant, the value of 'sub' SHOULD + # correspond to the subject identifier of the resource owner. + + if user: + token_data["sub"] = user.get_user_id() + + # In cases of access tokens obtained through grants where no resource owner is + # involved, such as the client credentials grant, the value of 'sub' SHOULD + # correspond to an identifier the authorization server uses to indicate the + # client application. + + else: + token_data["sub"] = client.get_client_id() + + # If the request includes a 'resource' parameter (as defined in [RFC8707]), the + # resulting JWT access token 'aud' claim SHOULD have the same value as the + # 'resource' parameter in the request. + + # TODO: Implement this with RFC8707 + if False: # pragma: no cover + ... + + # If the request does not include a 'resource' parameter, the authorization + # server MUST use a default resource indicator in the 'aud' claim. If a 'scope' + # parameter is present in the request, the authorization server SHOULD use it to + # infer the value of the default resource indicator to be used in the 'aud' + # claim. The mechanism through which scopes are associated with default resource + # indicator values is outside the scope of this specification. + + else: + token_data["aud"] = self.get_audiences(client, user, scope) + + # If the values in the 'scope' parameter refer to different default resource + # indicator values, the authorization server SHOULD reject the request with + # 'invalid_scope' as described in Section 4.1.2.1 of [RFC6749]. + # TODO: Implement this with RFC8707 + + if auth_time := self.get_auth_time(user): + token_data["auth_time"] = auth_time + + # The meaning and processing of acr Claim Values is out of scope for this + # specification. + + if acr := self.get_acr(user): + token_data["acr"] = acr + + # The definition of particular values to be used in the amr Claim is beyond the + # scope of this specification. + + if amr := self.get_amr(user): + token_data["amr"] = amr + + # Authorization servers MAY return arbitrary attributes not defined in any + # existing specification, as long as the corresponding claim names are collision + # resistant or the access tokens are meant to be used only within a private + # subsystem. Please refer to Sections 4.2 and 4.3 of [RFC7519] for details. + + token_data.update(self.get_extra_claims(client, grant_type, user, scope)) + + # This specification registers the 'application/at+jwt' media type, which can + # be used to indicate that the content is a JWT access token. JWT access tokens + # MUST include this media type in the 'typ' header parameter to explicitly + # declare that the JWT represents an access token complying with this profile. + # Per the definition of 'typ' in Section 4.1.9 of [RFC7515], it is RECOMMENDED + # that the 'application/' prefix be omitted. Therefore, the 'typ' value used + # SHOULD be 'at+jwt'. + + header = {"alg": self.alg, "typ": "at+jwt"} + key = import_any_key(self.get_jwks()) + access_token = jwt.encode( + header, + token_data, + key=key, + ) + return access_token diff --git a/authlib/oauth2/rfc9068/token_validator.py b/authlib/oauth2/rfc9068/token_validator.py new file mode 100644 index 000000000..dc3e7b80e --- /dev/null +++ b/authlib/oauth2/rfc9068/token_validator.py @@ -0,0 +1,168 @@ +"""authlib.oauth2.rfc9068.token_validator. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Implementation of Validating JWT Access Tokens per `Section 4`_. + +.. _`Section 7`: https://www.rfc-editor.org/rfc/rfc9068.html#name-validating-jwt-access-token +""" + +from joserfc import jwt +from joserfc.errors import DecodeError +from joserfc.errors import JoseError + +from authlib._joserfc_helpers import import_any_key +from authlib.oauth2.claims import ClaimsOption +from authlib.oauth2.rfc6750.errors import InsufficientScopeError +from authlib.oauth2.rfc6750.errors import InvalidTokenError +from authlib.oauth2.rfc6750.validator import BearerTokenValidator + +from .claims import JWTAccessTokenClaims + + +class JWTBearerTokenValidator(BearerTokenValidator): + """JWTBearerTokenValidator can protect your resource server endpoints. + + :param issuer: The issuer from which tokens will be accepted. + :param resource_server: An identifier for the current resource server, + which must appear in the JWT ``aud`` claim. + + Developers needs to implement the missing methods:: + + class MyJWTBearerTokenValidator(JWTBearerTokenValidator): + def get_jwks(self): ... + + + require_oauth = ResourceProtector() + require_oauth.register_token_validator( + MyJWTBearerTokenValidator( + issuer="https://authorization-server.example.org", + resource_server="https://resource-server.example.org", + ) + ) + + You can then protect resources depending on the JWT `scope`, `groups`, + `roles` or `entitlements` claims:: + + @require_oauth( + scope="profile", + groups="admins", + roles="student", + entitlements="captain", + ) + def resource_endpoint(): ... + """ + + def __init__(self, issuer, resource_server, *args, **kwargs): + self.issuer = issuer + self.resource_server = resource_server + super().__init__(*args, **kwargs) + + def get_jwks(self): + """Return the JWKs that will be used to check the JWT access token signature. + Developers MUST re-implement this method. Typically the JWKs are statically + stored in the resource server configuration, or dynamically downloaded and + cached using :ref:`specs/rfc8414`:: + + def get_jwks(self): + if "jwks" in cache: + return cache.get("jwks") + + server_metadata = get_server_metadata(self.issuer) + jwks_uri = server_metadata.get("jwks_uri") + cache["jwks"] = requests.get(jwks_uri).json() + return cache["jwks"] + """ + raise NotImplementedError() + + def validate_iss(self, claims, iss: "str") -> bool: + # The issuer identifier for the authorization server (which is typically + # obtained during discovery) MUST exactly match the value of the 'iss' + # claim. + return iss == self.issuer + + def authenticate_token(self, token_string): + """""" + # empty docstring avoids to display the irrelevant parent docstring + + claims_options: dict[str, ClaimsOption] = { + "iss": {"essential": True, "validate": self.validate_iss}, + "exp": {"essential": True}, + "aud": {"essential": True, "value": self.resource_server}, + "sub": {"essential": True}, + "client_id": {"essential": True}, + "iat": {"essential": True}, + "jti": {"essential": True}, + "auth_time": {"essential": False}, + "acr": {"essential": False}, + "amr": {"essential": False}, + "scope": {"essential": False}, + "groups": {"essential": False}, + "roles": {"essential": False}, + "entitlements": {"essential": False}, + } + key = import_any_key(self.get_jwks()) + + # If the JWT access token is encrypted, decrypt it using the keys and algorithms + # that the resource server specified during registration. If encryption was + # negotiated with the authorization server at registration time and the incoming + # JWT access token is not encrypted, the resource server SHOULD reject it. + + # The resource server MUST validate the signature of all incoming JWT access + # tokens according to [RFC7515] using the algorithm specified in the JWT 'alg' + # Header Parameter. The resource server MUST reject any JWT in which the value + # of 'alg' is 'none'. The resource server MUST use the keys provided by the + # authorization server. + try: + token = jwt.decode(token_string, key=key) + return JWTAccessTokenClaims(token.claims, token.header, claims_options) + except DecodeError as exc: + raise InvalidTokenError( + realm=self.realm, extra_attributes=self.extra_attributes + ) from exc + + def validate_token( + self, + token: JWTAccessTokenClaims, + scopes, + request, + groups=None, + roles=None, + entitlements=None, + ): + """""" + # empty docstring avoids to display the irrelevant parent docstring + try: + token.validate() + except JoseError as exc: + raise InvalidTokenError( + realm=self.realm, extra_attributes=self.extra_attributes + ) from exc + + # If an authorization request includes a scope parameter, the corresponding + # issued JWT access token SHOULD include a 'scope' claim as defined in Section + # 4.2 of [RFC8693]. All the individual scope strings in the 'scope' claim MUST + # have meaning for the resources indicated in the 'aud' claim. See Section 5 for + # more considerations about the relationship between scope strings and resources + # indicated by the 'aud' claim. + + if self.scope_insufficient(token.get("scope", []), scopes): + raise InsufficientScopeError() + + # Many authorization servers embed authorization attributes that go beyond the + # delegated scenarios described by [RFC7519] in the access tokens they issue. + # Typical examples include resource owner memberships in roles and groups that + # are relevant to the resource being accessed, entitlements assigned to the + # resource owner for the targeted resource that the authorization server knows + # about, and so on. An authorization server wanting to include such attributes + # in a JWT access token SHOULD use the 'groups', 'roles', and 'entitlements' + # attributes of the 'User' resource schema defined by Section 4.1.2 of + # [RFC7643]) as claim types. + + if self.scope_insufficient(token.get("groups"), groups): + raise InvalidTokenError() + + if self.scope_insufficient(token.get("roles"), roles): + raise InvalidTokenError() + + if self.scope_insufficient(token.get("entitlements"), entitlements): + raise InvalidTokenError() diff --git a/authlib/oauth2/rfc9101/__init__.py b/authlib/oauth2/rfc9101/__init__.py new file mode 100644 index 000000000..02194770d --- /dev/null +++ b/authlib/oauth2/rfc9101/__init__.py @@ -0,0 +1,9 @@ +from .authorization_server import JWTAuthenticationRequest +from .discovery import AuthorizationServerMetadata +from .registration import ClientMetadataClaims + +__all__ = [ + "AuthorizationServerMetadata", + "JWTAuthenticationRequest", + "ClientMetadataClaims", +] diff --git a/authlib/oauth2/rfc9101/authorization_server.py b/authlib/oauth2/rfc9101/authorization_server.py new file mode 100644 index 000000000..09c4f8c41 --- /dev/null +++ b/authlib/oauth2/rfc9101/authorization_server.py @@ -0,0 +1,279 @@ +from joserfc import jwt +from joserfc.errors import DecodeError +from joserfc.errors import JoseError +from joserfc.errors import UnsupportedAlgorithmError +from joserfc.jws import JWSRegistry + +from authlib._joserfc_helpers import import_any_key + +from ..rfc6749 import AuthorizationServer +from ..rfc6749 import ClientMixin +from ..rfc6749 import InvalidClientError +from ..rfc6749 import InvalidRequestError +from ..rfc6749.requests import BasicOAuth2Payload +from ..rfc6749.requests import OAuth2Request +from .errors import InvalidRequestObjectError +from .errors import InvalidRequestUriError +from .errors import RequestNotSupportedError +from .errors import RequestUriNotSupportedError + + +class JWTAuthenticationRequest: + """Authorization server extension implementing the support + for JWT secured authentication request, as defined in :rfc:`RFC9101 <9101>`. + + :param support_request: Whether to enable support for the ``request`` parameter. + :param support_request_uri: Whether to enable support for the ``request_uri`` parameter. + + This extension is intended to be inherited and registered into the authorization server:: + + class JWTAuthenticationRequest(rfc9101.JWTAuthenticationRequest): + def resolve_client_public_key(self, client: ClientMixin): + return get_jwks_for_client(client) + + def get_request_object(self, request_uri: str): + try: + return requests.get(request_uri).text + except requests.Exception: + return None + + def get_server_metadata(self): + return { + "issuer": ..., + "authorization_endpoint": ..., + "require_signed_request_object": ..., + } + + def get_client_require_signed_request_object(self, client: ClientMixin): + return client.require_signed_request_object + + + authorization_server.register_extension(JWTAuthenticationRequest()) + """ + + claims_validator = jwt.JWTClaimsRegistry( + client_id={"essential": True}, + ) + + def __init__(self, support_request: bool = True, support_request_uri: bool = True): + self.support_request = support_request + self.support_request_uri = support_request_uri + + def __call__(self, authorization_server: AuthorizationServer): + authorization_server.register_hook( + "before_get_authorization_grant", self.parse_authorization_request + ) + + def get_request_object_signing_algorithms(self, client): + """Return the supported algorithms for verifying the ``request_object`` JWT signature. + By default, this method will only return the recommended algorithms. If signed request + object is not required, "none" algorithm will be included. + + Developers can override this method to customize the supported algorithms:: + + def get_request_object_signing_algorithms(self, client): + return ["RS256"] + """ + metadata = self.get_server_metadata() + algorithms = metadata.get("request_object_signing_alg_values_supported") + if not algorithms: + require_signed1 = self.get_client_require_signed_request_object(client) + require_signed2 = metadata.get("require_signed_request_object", False) + if require_signed1 or require_signed2: + algorithms = JWSRegistry.recommended + else: + algorithms = [*JWSRegistry.recommended, "none"] + return algorithms + + def parse_authorization_request( + self, authorization_server: AuthorizationServer, request: OAuth2Request + ): + client_id = request.payload.client_id + if client_id is None: + raise InvalidClientError( + status_code=404, + description="Missing 'client_id' parameter.", + ) + + client = authorization_server.query_client(client_id) + if not client: + raise InvalidClientError( + status_code=404, + description="The client does not exist on this server.", + ) + + if not self._shoud_proceed_with_request_object(request, client): + return + + raw_request_object = self._get_raw_request_object(request) + request_object = self._decode_request_object( + request, client, raw_request_object + ) + payload = BasicOAuth2Payload(request_object.claims) + request.payload = payload + + def _shoud_proceed_with_request_object( + self, + request: OAuth2Request, + client: ClientMixin, + ) -> bool: + if "request" in request.payload.data and "request_uri" in request.payload.data: + raise InvalidRequestError( + "The 'request' and 'request_uri' parameters are mutually exclusive.", + state=request.payload.state, + ) + + if "request" in request.payload.data: + if not self.support_request: + raise RequestNotSupportedError(state=request.payload.state) + return True + + if "request_uri" in request.payload.data: + if not self.support_request_uri: + raise RequestUriNotSupportedError(state=request.payload.state) + return True + + # When the value of it [require_signed_request_object] as client metadata is true, + # then the server MUST reject the authorization request + # from the client that does not conform to this specification. + if self.get_client_require_signed_request_object(client): + raise InvalidRequestError( + "Authorization requests for this client must use signed request objects.", + state=request.payload.state, + ) + + # When the value of it [require_signed_request_object] as server metadata is true, + # then the server MUST reject the authorization request + # from any client that does not conform to this specification. + metadata = self.get_server_metadata() + if metadata and metadata.get("require_signed_request_object", False): + raise InvalidRequestError( + "Authorization requests for this server must use signed request objects.", + state=request.payload.state, + ) + + return False + + def _get_raw_request_object(self, request: OAuth2Request) -> str: + if "request_uri" in request.payload.data: + raw_request_object = self.get_request_object( + request.payload.data["request_uri"] + ) + if not raw_request_object: + raise InvalidRequestUriError(state=request.payload.state) + + else: + raw_request_object = request.payload.data["request"] + + return raw_request_object + + def _decode_request_object( + self, request, client: ClientMixin, raw_request_object: str + ): + jwks = self.resolve_client_public_key(client) + key = import_any_key(jwks) + algorithms = self.get_request_object_signing_algorithms(client) + + try: + request_object = jwt.decode(raw_request_object, key, algorithms=algorithms) + self.claims_validator.validate(request_object.claims) + except UnsupportedAlgorithmError as error: + raise InvalidRequestError( + "Authorization requests must be signed with supported algorithms.", + state=request.payload.state, + ) from error + except DecodeError as error: + raise InvalidRequestObjectError(state=request.payload.state) from error + except JoseError as error: + raise InvalidRequestObjectError( + description=error.description or InvalidRequestObjectError.description, + state=request.payload.state, + ) from error + + # The client ID values in the client_id request parameter and in + # the Request Object client_id claim MUST be identical. + if request_object.claims["client_id"] != request.payload.client_id: + raise InvalidRequestError( + "The 'client_id' claim from the request parameters " + "and the request object claims don't match.", + state=request.payload.state, + ) + + # The Request Object MAY be sent by value, as described in Section 5.1, + # or by reference, as described in Section 5.2. request and + # request_uri parameters MUST NOT be included in Request Objects. + if "request" in request_object.claims or "request_uri" in request_object.claims: + raise InvalidRequestError( + "The 'request' and 'request_uri' parameters must not be included in the request object.", + state=request.payload.state, + ) + + return request_object + + def get_request_object(self, request_uri: str): + """Download the request object at ``request_uri``. + + This method must be implemented if the ``request_uri`` parameter is supported:: + + class JWTAuthenticationRequest(rfc9101.JWTAuthenticationRequest): + def get_request_object(self, request_uri: str): + try: + return requests.get(request_uri).text + except requests.Exception: + return None + """ + raise NotImplementedError() + + def resolve_client_public_key(self, client: ClientMixin): + """Resolve the client public key for verifying the JWT signature. + A client may have many public keys, in this case, we can retrieve it + via ``kid`` value in headers. Developers MUST implement this method:: + + from joserfc import KeySet + + + class JWTAuthenticationRequest(rfc9101.JWTAuthenticationRequest): + def resolve_client_public_key(self, client): + if client.jwks_uri: + data = requests.get(client.jwks_uri).json() + return KeySet.import_key_set(data) + + return KeySet.import_key_set(client.jwks) + """ + raise NotImplementedError() + + def get_server_metadata(self) -> dict: + """Return server metadata which includes supported grant types, + response types and etc. + + When the ``require_signed_request_object`` claim is :data:`True`, + all clients require that authorization requests + use request objects, and an error will be returned when the authorization + request payload is passed in the request body or query string:: + + class JWTAuthenticationRequest(rfc9101.JWTAuthenticationRequest): + def get_server_metadata(self): + return { + "issuer": ..., + "authorization_endpoint": ..., + "require_signed_request_object": ..., + "request_object_signing_alg_values_supported": ["RS256", ...], + } + + """ + return {} # pragma: no cover + + def get_client_require_signed_request_object(self, client: ClientMixin) -> bool: + """Return the 'require_signed_request_object' client metadata. + + When :data:`True`, the client requires that authorization requests + use request objects, and an error will be returned when the authorization + request payload is passed in the request body or query string:: + + class JWTAuthenticationRequest(rfc9101.JWTAuthenticationRequest): + def get_client_require_signed_request_object(self, client): + return client.require_signed_request_object + + If not implemented, the value is considered as :data:`False`. + """ + return False # pragma: no cover diff --git a/authlib/oauth2/rfc9101/discovery.py b/authlib/oauth2/rfc9101/discovery.py new file mode 100644 index 000000000..8468922ab --- /dev/null +++ b/authlib/oauth2/rfc9101/discovery.py @@ -0,0 +1,21 @@ +from authlib.oauth2.rfc8414.models import validate_boolean_value + + +class AuthorizationServerMetadata(dict): + """Authorization Server Metadata extension for RFC9101 (JAR). + + This class can be used with + :meth:`~authlib.oauth2.rfc8414.AuthorizationServerMetadata.validate` + to validate JAR-specific metadata:: + + from authlib.oauth2 import rfc8414, rfc9101 + + metadata = rfc8414.AuthorizationServerMetadata(data) + metadata.validate(metadata_classes=[rfc9101.AuthorizationServerMetadata]) + """ + + REGISTRY_KEYS = ["require_signed_request_object"] + + def validate_require_signed_request_object(self): + """Indicates where authorization request needs to be protected as Request Object and provided through either request or request_uri parameter.""" + validate_boolean_value(self, "require_signed_request_object") diff --git a/authlib/oauth2/rfc9101/errors.py b/authlib/oauth2/rfc9101/errors.py new file mode 100644 index 000000000..3feeeaabf --- /dev/null +++ b/authlib/oauth2/rfc9101/errors.py @@ -0,0 +1,34 @@ +from ..base import OAuth2Error + +__all__ = [ + "InvalidRequestUriError", + "InvalidRequestObjectError", + "RequestNotSupportedError", + "RequestUriNotSupportedError", +] + + +class InvalidRequestUriError(OAuth2Error): + error = "invalid_request_uri" + description = "The request_uri in the authorization request returns an error or contains invalid data." + status_code = 400 + + +class InvalidRequestObjectError(OAuth2Error): + error = "invalid_request_object" + description = "The request parameter contains an invalid Request Object." + status_code = 400 + + +class RequestNotSupportedError(OAuth2Error): + error = "request_not_supported" + description = ( + "The authorization server does not support the use of the request parameter." + ) + status_code = 400 + + +class RequestUriNotSupportedError(OAuth2Error): + error = "request_uri_not_supported" + description = "The authorization server does not support the use of the request_uri parameter." + status_code = 400 diff --git a/authlib/oauth2/rfc9101/registration.py b/authlib/oauth2/rfc9101/registration.py new file mode 100644 index 000000000..a8d3bab66 --- /dev/null +++ b/authlib/oauth2/rfc9101/registration.py @@ -0,0 +1,43 @@ +from joserfc.errors import InvalidClaimError + +from authlib.oauth2.claims import BaseClaims + + +class ClientMetadataClaims(BaseClaims): + """Additional client metadata can be used with :ref:`specs/rfc7591` and :ref:`specs/rfc7592` endpoints. + + This can be used with:: + + server.register_endpoint( + ClientRegistrationEndpoint( + claims_classes=[ + rfc7591.ClientMetadataClaims, + rfc9101.ClientMetadataClaims, + ] + ) + ) + + server.register_endpoint( + ClientRegistrationEndpoint( + claims_classes=[ + rfc7591.ClientMetadataClaims, + rfc9101.ClientMetadataClaims, + ] + ) + ) + + """ + + REGISTERED_CLAIMS = [ + "require_signed_request_object", + ] + + def validate(self, now=None, leeway=0): + super().validate(now, leeway) + self.validate_require_signed_request_object() + + def validate_require_signed_request_object(self): + self.setdefault("require_signed_request_object", False) + + if not isinstance(self["require_signed_request_object"], bool): + raise InvalidClaimError("require_signed_request_object") diff --git a/authlib/oauth2/rfc9207/__init__.py b/authlib/oauth2/rfc9207/__init__.py new file mode 100644 index 000000000..cdf7106db --- /dev/null +++ b/authlib/oauth2/rfc9207/__init__.py @@ -0,0 +1,4 @@ +from .discovery import AuthorizationServerMetadata +from .parameter import IssuerParameter + +__all__ = ["AuthorizationServerMetadata", "IssuerParameter"] diff --git a/authlib/oauth2/rfc9207/discovery.py b/authlib/oauth2/rfc9207/discovery.py new file mode 100644 index 000000000..f863772b4 --- /dev/null +++ b/authlib/oauth2/rfc9207/discovery.py @@ -0,0 +1,25 @@ +from authlib.oauth2.rfc8414.models import validate_boolean_value + + +class AuthorizationServerMetadata(dict): + """Authorization Server Metadata extension for RFC9207. + + This class can be used with + :meth:`~authlib.oauth2.rfc8414.AuthorizationServerMetadata.validate` + to validate RFC9207-specific metadata:: + + from authlib.oauth2 import rfc8414, rfc9207 + + metadata = rfc8414.AuthorizationServerMetadata(data) + metadata.validate(metadata_classes=[rfc9207.AuthorizationServerMetadata]) + """ + + REGISTRY_KEYS = ["authorization_response_iss_parameter_supported"] + + def validate_authorization_response_iss_parameter_supported(self): + """Boolean parameter indicating whether the authorization server + provides the iss parameter in the authorization response. + + If omitted, the default value is false. + """ + validate_boolean_value(self, "authorization_response_iss_parameter_supported") diff --git a/authlib/oauth2/rfc9207/parameter.py b/authlib/oauth2/rfc9207/parameter.py new file mode 100644 index 000000000..09e616a4e --- /dev/null +++ b/authlib/oauth2/rfc9207/parameter.py @@ -0,0 +1,43 @@ +from authlib.common.urls import add_params_to_uri +from authlib.deprecate import deprecate +from authlib.oauth2.rfc6749.grants import BaseGrant + + +class IssuerParameter: + def __call__(self, authorization_server): + if isinstance(authorization_server, BaseGrant): + deprecate( + "IssueParameter should be used as an authorization server extension with 'authorization_server.register_extension(IssueParameter())'.", + version="1.8", + ) + authorization_server.register_hook( + "after_authorization_response", + self.add_issuer_parameter, + ) + + else: + authorization_server.register_hook( + "after_create_authorization_response", + self.add_issuer_parameter, + ) + + def add_issuer_parameter(self, authorization_server, response): + if self.get_issuer() and response.location: + # RFC9207 §2 + # In authorization responses to the client, including error responses, + # an authorization server supporting this specification MUST indicate + # its identity by including the iss parameter in the response. + + new_location = add_params_to_uri( + response.location, {"iss": self.get_issuer()} + ) + response.location = new_location + + def get_issuer(self) -> str | None: + """Return the issuer URL. + Developers MAY implement this method if they want to support :rfc:`RFC9207 <9207>`:: + + def get_issuer(self) -> str: + return "https://auth.example.org" + """ + return None diff --git a/authlib/oidc/core/__init__.py b/authlib/oidc/core/__init__.py index 8ee628fa0..62649e020 100644 --- a/authlib/oidc/core/__init__.py +++ b/authlib/oidc/core/__init__.py @@ -1,23 +1,35 @@ -""" - authlib.oidc.core - ~~~~~~~~~~~~~~~~~ +"""authlib.oidc.core. +~~~~~~~~~~~~~~~~~ - OpenID Connect Core 1.0 Implementation. +OpenID Connect Core 1.0 Implementation. - http://openid.net/specs/openid-connect-core-1_0.html +http://openid.net/specs/openid-connect-core-1_0.html """ +from .claims import CodeIDToken +from .claims import HybridIDToken +from .claims import IDToken +from .claims import ImplicitIDToken +from .claims import UserInfo +from .claims import get_claim_cls_by_response_type +from .grants import OpenIDCode +from .grants import OpenIDHybridGrant +from .grants import OpenIDImplicitGrant +from .grants import OpenIDToken from .models import AuthorizationCodeMixin -from .claims import ( - IDToken, CodeIDToken, ImplicitIDToken, HybridIDToken, - UserInfo, get_claim_cls_by_response_type, -) -from .grants import OpenIDCode, OpenIDHybridGrant, OpenIDImplicitGrant - +from .userinfo import UserInfoEndpoint __all__ = [ - 'AuthorizationCodeMixin', - 'IDToken', 'CodeIDToken', 'ImplicitIDToken', 'HybridIDToken', - 'UserInfo', 'get_claim_cls_by_response_type', - 'OpenIDCode', 'OpenIDHybridGrant', 'OpenIDImplicitGrant', + "AuthorizationCodeMixin", + "IDToken", + "CodeIDToken", + "ImplicitIDToken", + "HybridIDToken", + "UserInfo", + "UserInfoEndpoint", + "get_claim_cls_by_response_type", + "OpenIDToken", + "OpenIDCode", + "OpenIDHybridGrant", + "OpenIDImplicitGrant", ] diff --git a/authlib/oidc/core/claims.py b/authlib/oidc/core/claims.py index dc3a84301..606a9d904 100644 --- a/authlib/oidc/core/claims.py +++ b/authlib/oidc/core/claims.py @@ -1,46 +1,52 @@ -import time +from __future__ import annotations + import hmac + +from joserfc.errors import InvalidClaimError +from joserfc.errors import MissingClaimError + from authlib.common.encoding import to_bytes -from authlib.jose import JWTClaims -from authlib.jose.errors import ( - MissingClaimError, - InvalidClaimError, -) +from authlib.oauth2.claims import JWTClaims +from authlib.oauth2.rfc6749.util import scope_to_list + from .util import create_half_hash __all__ = [ - 'IDToken', 'CodeIDToken', 'ImplicitIDToken', 'HybridIDToken', - 'UserInfo', 'get_claim_cls_by_response_type' + "IDToken", + "CodeIDToken", + "ImplicitIDToken", + "HybridIDToken", + "UserInfo", + "get_claim_cls_by_response_type", ] _REGISTERED_CLAIMS = [ - 'iss', 'sub', 'aud', 'exp', 'nbf', 'iat', - 'auth_time', 'nonce', 'acr', 'amr', 'azp', - 'at_hash', + "iss", + "sub", + "aud", + "exp", + "nbf", + "iat", + "auth_time", + "nonce", + "acr", + "amr", + "azp", + "at_hash", ] class IDToken(JWTClaims): - ESSENTIAL_CLAIMS = ['iss', 'sub', 'aud', 'exp', 'iat'] + ESSENTIAL_CLAIMS = ["iss", "sub", "aud", "exp", "iat"] def validate(self, now=None, leeway=0): for k in self.ESSENTIAL_CLAIMS: if k not in self: raise MissingClaimError(k) - self._validate_essential_claims() - if now is None: - now = int(time.time()) - - self.validate_iss() - self.validate_sub() - self.validate_aud() - self.validate_exp(now, leeway) - self.validate_nbf(now, leeway) - self.validate_iat(now, leeway) + super().validate(now, leeway) self.validate_auth_time() self.validate_nonce() - self.validate_acr() self.validate_amr() self.validate_azp() self.validate_at_hash() @@ -52,12 +58,12 @@ def validate_auth_time(self): when auth_time is requested as an Essential Claim, then this Claim is REQUIRED; otherwise, its inclusion is OPTIONAL. """ - auth_time = self.get('auth_time') - if self.params.get('max_age') and not auth_time: - raise MissingClaimError('auth_time') + auth_time = self.get("auth_time") + if self.params.get("max_age") and not auth_time: + raise MissingClaimError("auth_time") - if auth_time and not isinstance(auth_time, int): - raise InvalidClaimError('auth_time') + if auth_time and not isinstance(auth_time, (int, float)): + raise InvalidClaimError("auth_time") def validate_nonce(self): """String value used to associate a Client session with an ID Token, @@ -71,32 +77,12 @@ def validate_nonce(self): SHOULD perform no other processing on nonce values used. The nonce value is a case sensitive string. """ - nonce_value = self.params.get('nonce') + nonce_value = self.params.get("nonce") if nonce_value: - if 'nonce' not in self: - raise MissingClaimError('nonce') - if nonce_value != self['nonce']: - raise InvalidClaimError('nonce') - - def validate_acr(self): - """OPTIONAL. Authentication Context Class Reference. String specifying - an Authentication Context Class Reference value that identifies the - Authentication Context Class that the authentication performed - satisfied. The value "0" indicates the End-User authentication did not - meet the requirements of `ISO/IEC 29115`_ level 1. Authentication - using a long-lived browser cookie, for instance, is one example where - the use of "level 0" is appropriate. Authentications with level 0 - SHOULD NOT be used to authorize access to any resource of any monetary - value. An absolute URI or an `RFC 6711`_ registered name SHOULD be - used as the acr value; registered names MUST NOT be used with a - different meaning than that which is registered. Parties using this - claim will need to agree upon the meanings of the values used, which - may be context-specific. The acr value is a case sensitive string. - - .. _`ISO/IEC 29115`: https://www.iso.org/standard/45138.html - .. _`RFC 6711`: https://tools.ietf.org/html/rfc6711 - """ - return self._validate_claim_value('acr') + if "nonce" not in self: + raise MissingClaimError("nonce") + if nonce_value != self["nonce"]: + raise InvalidClaimError("nonce") def validate_amr(self): """OPTIONAL. Authentication Methods References. JSON array of strings @@ -108,9 +94,9 @@ def validate_amr(self): meanings of the values used, which may be context-specific. The amr value is an array of case sensitive strings. """ - amr = self.get('amr') - if amr and not isinstance(self['amr'], list): - raise InvalidClaimError('amr') + amr = self.get("amr") + if amr and not isinstance(self["amr"], list): + raise InvalidClaimError("amr") def validate_azp(self): """OPTIONAL. Authorized party - the party to which the ID Token was @@ -121,8 +107,8 @@ def validate_azp(self): as the sole audience. The azp value is a case sensitive string containing a StringOrURI value. """ - aud = self.get('aud') - client_id = self.params.get('client_id') + aud = self.get("aud") + client_id = self.params.get("client_id") required = False if aud and client_id: if isinstance(aud, list) and len(aud) == 1: @@ -130,12 +116,12 @@ def validate_azp(self): if aud != client_id: required = True - azp = self.get('azp') + azp = self.get("azp") if required and not azp: - raise MissingClaimError('azp') + raise MissingClaimError("azp") if azp and client_id and azp != client_id: - raise InvalidClaimError('azp') + raise InvalidClaimError("azp") def validate_at_hash(self): """OPTIONAL. Access Token hash value. Its value is the base64url @@ -146,21 +132,21 @@ def validate_at_hash(self): access_token value with SHA-256, then take the left-most 128 bits and base64url encode them. The at_hash value is a case sensitive string. """ - access_token = self.params.get('access_token') - at_hash = self.get('at_hash') + access_token = self.params.get("access_token") + at_hash = self.get("at_hash") if at_hash and access_token: - if not _verify_hash(at_hash, access_token, self.header['alg']): - raise InvalidClaimError('at_hash') + if not _verify_hash(at_hash, access_token, self.header["alg"]): + raise InvalidClaimError("at_hash") class CodeIDToken(IDToken): - RESPONSE_TYPES = ('code',) + RESPONSE_TYPES = ("code",) REGISTERED_CLAIMS = _REGISTERED_CLAIMS class ImplicitIDToken(IDToken): - RESPONSE_TYPES = ('id_token', 'id_token token') - ESSENTIAL_CLAIMS = ['iss', 'sub', 'aud', 'exp', 'iat', 'nonce'] + RESPONSE_TYPES = ("id_token", "id_token token") + ESSENTIAL_CLAIMS = ["iss", "sub", "aud", "exp", "iat", "nonce"] REGISTERED_CLAIMS = _REGISTERED_CLAIMS def validate_at_hash(self): @@ -170,18 +156,18 @@ def validate_at_hash(self): Token is issued, which is the case for the response_type value id_token. """ - access_token = self.params.get('access_token') - if access_token and 'at_hash' not in self: - raise MissingClaimError('at_hash') - super(ImplicitIDToken, self).validate_at_hash() + access_token = self.params.get("access_token") + if access_token and "at_hash" not in self: + raise MissingClaimError("at_hash") + super().validate_at_hash() class HybridIDToken(ImplicitIDToken): - RESPONSE_TYPES = ('code id_token', 'code token', 'code id_token token') - REGISTERED_CLAIMS = _REGISTERED_CLAIMS + ['c_hash'] + RESPONSE_TYPES = ("code id_token", "code token", "code id_token token") + REGISTERED_CLAIMS = _REGISTERED_CLAIMS + ["c_hash"] def validate(self, now=None, leeway=0): - super(HybridIDToken, self).validate(now=now, leeway=leeway) + super().validate(now=now, leeway=leeway) self.validate_c_hash() def validate_c_hash(self): @@ -196,13 +182,13 @@ def validate_c_hash(self): which is the case for the response_type values code id_token and code id_token token, this is REQUIRED; otherwise, its inclusion is OPTIONAL. """ - code = self.params.get('code') - c_hash = self.get('c_hash') + code = self.params.get("code") + c_hash = self.get("c_hash") if code: if not c_hash: - raise MissingClaimError('c_hash') - if not _verify_hash(c_hash, code, self.header['alg']): - raise InvalidClaimError('c_hash') + raise MissingClaimError("c_hash") + if not _verify_hash(c_hash, code, self.header["alg"]): + raise InvalidClaimError("c_hash") class UserInfo(dict): @@ -213,12 +199,64 @@ class UserInfo(dict): #: registered claims that UserInfo supports REGISTERED_CLAIMS = [ - 'sub', 'name', 'given_name', 'family_name', 'middle_name', 'nickname', - 'preferred_username', 'profile', 'picture', 'website', 'email', - 'email_verified', 'gender', 'birthdate', 'zoneinfo', 'locale', - 'phone_number', 'phone_number_verified', 'address', 'updated_at', + "sub", + "name", + "given_name", + "family_name", + "middle_name", + "nickname", + "preferred_username", + "profile", + "picture", + "website", + "email", + "email_verified", + "gender", + "birthdate", + "zoneinfo", + "locale", + "phone_number", + "phone_number_verified", + "address", + "updated_at", ] + SCOPES_CLAIMS_MAPPING = { + "openid": ["sub"], + "profile": [ + "name", + "family_name", + "given_name", + "middle_name", + "nickname", + "preferred_username", + "profile", + "picture", + "website", + "gender", + "birthdate", + "zoneinfo", + "locale", + "updated_at", + ], + "email": ["email", "email_verified"], + "address": ["address"], + "phone": ["phone_number", "phone_number_verified"], + } + + def filter(self, scope: str): + """Return a new UserInfo object containing only the claims matching the scope passed in parameter.""" + scope = scope_to_list(scope) + filtered_claims = [ + claim + for scope_part in scope + for claim in self.SCOPES_CLAIMS_MAPPING.get(scope_part, []) + ] + filtered_items = { + key: val for key, val in self.items() if key in filtered_claims + } + return UserInfo(filtered_items) + def __getattr__(self, key): try: return object.__getattribute__(self, key) @@ -237,6 +275,6 @@ def get_claim_cls_by_response_type(response_type): def _verify_hash(signature, s, alg): hash_value = create_half_hash(s, alg) - if not hash_value: - return True + if hash_value is None: + return False return hmac.compare_digest(hash_value, to_bytes(signature)) diff --git a/authlib/oidc/core/errors.py b/authlib/oidc/core/errors.py index e5fb630e3..a2ed7609c 100644 --- a/authlib/oidc/core/errors.py +++ b/authlib/oidc/core/errors.py @@ -10,7 +10,8 @@ class InteractionRequiredError(OAuth2Error): http://openid.net/specs/openid-connect-core-1_0.html#AuthError """ - error = 'interaction_required' + + error = "interaction_required" class LoginRequiredError(OAuth2Error): @@ -21,7 +22,8 @@ class LoginRequiredError(OAuth2Error): http://openid.net/specs/openid-connect-core-1_0.html#AuthError """ - error = 'login_required' + + error = "login_required" class AccountSelectionRequiredError(OAuth2Error): @@ -35,7 +37,8 @@ class AccountSelectionRequiredError(OAuth2Error): http://openid.net/specs/openid-connect-core-1_0.html#AuthError """ - error = 'account_selection_required' + + error = "account_selection_required" class ConsentRequiredError(OAuth2Error): @@ -46,7 +49,8 @@ class ConsentRequiredError(OAuth2Error): http://openid.net/specs/openid-connect-core-1_0.html#AuthError """ - error = 'consent_required' + + error = "consent_required" class InvalidRequestURIError(OAuth2Error): @@ -55,24 +59,29 @@ class InvalidRequestURIError(OAuth2Error): http://openid.net/specs/openid-connect-core-1_0.html#AuthError """ - error = 'invalid_request_uri' + + error = "invalid_request_uri" class InvalidRequestObjectError(OAuth2Error): """The request parameter contains an invalid Request Object.""" - error = 'invalid_request_object' + + error = "invalid_request_object" class RequestNotSupportedError(OAuth2Error): """The OP does not support use of the request parameter.""" - error = 'request_not_supported' + + error = "request_not_supported" class RequestURINotSupportedError(OAuth2Error): """The OP does not support use of the request_uri parameter.""" - error = 'request_uri_not_supported' + + error = "request_uri_not_supported" class RegistrationNotSupportedError(OAuth2Error): """The OP does not support use of the registration parameter.""" - error = 'registration_not_supported' + + error = "registration_not_supported" diff --git a/authlib/oidc/core/grants/__init__.py b/authlib/oidc/core/grants/__init__.py index fb60bb72c..d01ac083c 100644 --- a/authlib/oidc/core/grants/__init__.py +++ b/authlib/oidc/core/grants/__init__.py @@ -1,9 +1,11 @@ from .code import OpenIDCode -from .implicit import OpenIDImplicitGrant +from .code import OpenIDToken from .hybrid import OpenIDHybridGrant +from .implicit import OpenIDImplicitGrant __all__ = [ - 'OpenIDCode', - 'OpenIDImplicitGrant', - 'OpenIDHybridGrant', + "OpenIDToken", + "OpenIDCode", + "OpenIDImplicitGrant", + "OpenIDHybridGrant", ] diff --git a/authlib/oidc/core/grants/_legacy.py b/authlib/oidc/core/grants/_legacy.py new file mode 100644 index 000000000..1001554d8 --- /dev/null +++ b/authlib/oidc/core/grants/_legacy.py @@ -0,0 +1,103 @@ +import time +import warnings + +from authlib.oauth2 import OAuth2Request + + +class LegacyMixin: + DEFAULT_EXPIRES_IN = 3600 + + def resolve_client_private_key(self, client): + """Resolve the client private key for encoding ``id_token`` Developers + MUST implement this method in subclass, e.g.:: + + import json + from joserfc.jwk import KeySet + + + def resolve_client_private_key(self, client): + with open(jwks_file_path) as f: + data = json.load(f) + return KeySet.import_key_set(data) + """ + config = self._compatible_resolve_jwt_config(None, client) + return config["key"] + + def get_client_algorithm(self, client): + """Return the algorithm for encoding ``id_token``. By default, it will + use ``client.id_token_signed_response_alg``, if not defined, ``RS256`` + will be used. But you can override this method to customize the returned + algorithm. + """ + # Per OpenID Connect Registration 1.0 Section 2: + # Use client's id_token_signed_response_alg if specified + config = self._compatible_resolve_jwt_config(None, client) + alg = config.get("alg") + if alg: + return alg + + if hasattr(client, "id_token_signed_response_alg"): + return client.id_token_signed_response_alg or "RS256" + return "RS256" + + def get_client_claims(self, client): + """Return the default client claims for encoding the ``id_token``. Developers + MUST implement this method in subclass, e.g.:: + + def get_client_claims(self, client): + return { + "iss": "your-service-url", + "aud": [client.get_client_id()], + } + """ + config = self._compatible_resolve_jwt_config(None, client) + claims = {k: config[k] for k in config if k not in ["key", "alg"]} + if "exp" in config: + now = int(time.time()) + claims["exp"] = now + config["exp"] + return claims + + def get_encode_header(self, client): + config = self._compatible_resolve_jwt_config(None, client) + kid = config.get("kid") + header = {"alg": self.get_client_algorithm(client)} + if kid: + header["kid"] = kid + return header + + def get_compatible_claims(self, request: OAuth2Request): + now = int(time.time()) + + claims = self.get_client_claims(request.client) + claims.setdefault("iat", now) + claims.setdefault("exp", now + self.DEFAULT_EXPIRES_IN) + claims.setdefault("auth_time", now) + + # compatible code + if "aud" not in claims and hasattr(self, "get_audiences"): + warnings.warn( + "get_audiences(self, request) is deprecated and will be removed in version 1.8. " + "You can set the ``aud`` value in get_client_claims instead.", + DeprecationWarning, + stacklevel=2, + ) + claims["aud"] = self.get_audiences(request) + + claims.setdefault("aud", [request.client.get_client_id()]) + return claims + + def _compatible_resolve_jwt_config(self, grant, client): + if not hasattr(self, "get_jwt_config"): + return {} + + warnings.warn( + "get_jwt_config(self, grant) is deprecated and will be removed in version 1.8. " + "Use resolve_client_private_key, get_client_claims, get_client_algorithm instead.", + DeprecationWarning, + stacklevel=2, + ) + try: + config = self.get_jwt_config(grant, client) + except TypeError: + config = self.get_jwt_config(grant) + return config diff --git a/authlib/oidc/core/grants/code.py b/authlib/oidc/core/grants/code.py index 7fb3265ac..c482590f6 100644 --- a/authlib/oidc/core/grants/code.py +++ b/authlib/oidc/core/grants/code.py @@ -1,74 +1,54 @@ -""" - authlib.oidc.core.grants.code - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oidc.core.grants.code. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - Implementation of Authentication using the Authorization Code Flow - per `Section 3.1`_. +Implementation of Authentication using the Authorization Code Flow +per `Section 3.1`_. - .. _`Section 3.1`: http://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth +.. _`Section 3.1`: https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth """ import logging -from .util import ( - is_openid_scope, - validate_nonce, - validate_request_prompt, - generate_id_token, -) -log = logging.getLogger(__name__) +from joserfc import jwt +from authlib._joserfc_helpers import import_any_key +from authlib.oauth2.rfc6749 import OAuth2Request -class OpenIDCode(object): - """An extension from OpenID Connect for "grant_type=code" request. - """ - def __init__(self, require_nonce=False): - self.require_nonce = require_nonce +from ..models import AuthorizationCodeMixin +from ._legacy import LegacyMixin +from .util import create_half_hash +from .util import is_openid_scope +from .util import validate_nonce +from .util import validate_request_prompt - def exists_nonce(self, nonce, request): - """Check if the given nonce is existing in your database. Developers - MUST implement this method in subclass, e.g.:: - - def exists_nonce(self, nonce, request): - exists = AuthorizationCode.query.filter_by( - client_id=request.client_id, nonce=nonce - ).first() - return bool(exists) +log = logging.getLogger(__name__) - :param nonce: A string of "nonce" parameter in request - :param request: OAuth2Request instance - :return: Boolean - """ - raise NotImplementedError() - def get_jwt_config(self, grant): # pragma: no cover - """Get the JWT configuration for OpenIDCode extension. The JWT - configuration will be used to generate ``id_token``. Developers - MUST implement this method in subclass, e.g.:: +class OpenIDToken(LegacyMixin): + def get_authorization_code_claims(self, authorization_code: AuthorizationCodeMixin): + claims = { + "nonce": authorization_code.get_nonce(), + "auth_time": authorization_code.get_auth_time(), + } - def get_jwt_config(self, grant): - return { - 'key': read_private_key_file(key_path), - 'alg': 'RS512', - 'iss': 'issuer-identity', - 'exp': 3600 - } + if acr := authorization_code.get_acr(): + claims["acr"] = acr - :param grant: AuthorizationCodeGrant instance - :return: dict - """ - raise NotImplementedError() + if amr := authorization_code.get_amr(): + claims["amr"] = amr + return claims - def generate_user_info(self, user, scope): # pragma: no cover + def generate_user_info(self, user, scope): """Provide user information for the given scope. Developers MUST implement this method in subclass, e.g.:: from authlib.oidc.core import UserInfo + def generate_user_info(self, user, scope): user_info = UserInfo(sub=user.id, name=user.name) - if 'email' in scope: - user_info['email'] = user.email + if "email" in scope: + user_info["email"] = user.email return user_info :param user: user instance @@ -77,43 +57,101 @@ def generate_user_info(self, user, scope): """ raise NotImplementedError() - def get_audiences(self, request): - """Parse `aud` value for id_token, default value is client id. Developers - MAY rewrite this method to provide a customized audience value. - """ - client = request.client - return [client.get_client_id()] + def encode_id_token(self, token, request: OAuth2Request): + alg = self.get_client_algorithm(request.client) + header = self.get_encode_header(request.client) + + claims = self.get_compatible_claims(request) + if request.authorization_code: + claims.update( + self.get_authorization_code_claims(request.authorization_code) + ) + + access_token = token.get("access_token") + if access_token: + at_hash = create_half_hash(access_token, alg) + if at_hash is not None: + claims["at_hash"] = at_hash.decode("utf-8") + + user_info = self.generate_user_info(request.user, token["scope"]) + claims.update(user_info) + + if alg == "none": + private_key = None + else: + key = self.resolve_client_private_key(request.client) + private_key = import_any_key(key) + + return jwt.encode(header, claims, private_key, [alg]) - def process_token(self, grant, token): - scope = token.get('scope') + def process_token(self, grant, response): + _, token, _ = response + scope = token.get("scope") if not scope or not is_openid_scope(scope): # standard authorization code flow return token - request = grant.request - credential = request.credential + request: OAuth2Request = grant.request + id_token = self.encode_id_token(token, request) + token["id_token"] = id_token + return token + + def __call__(self, grant): + grant.register_hook("after_create_token_response", self.process_token) - config = self.get_jwt_config(grant) - config['aud'] = self.get_audiences(request) - config['nonce'] = credential.get_nonce() - config['auth_time'] = credential.get_auth_time() - user_info = self.generate_user_info(request.user, token['scope']) - id_token = generate_id_token(token, user_info, **config) - token['id_token'] = id_token - return token +class OpenIDCode(OpenIDToken): + """An extension from OpenID Connect for "grant_type=code" request. Developers + MUST implement the missing methods:: + + class MyOpenIDCode(OpenIDCode): + def resolve_client_private_key(self, client): + with open(jwks_file_path) as f: + data = json.load(f) + return KeySet.import_key_set(data) + + def exists_nonce(self, nonce, request): + return check_if_nonce_in_cache(request.payload.client_id, nonce) + + def generate_user_info(self, user, scope): + return {...} + + The register this extension with AuthorizationCodeGrant:: + + authorization_server.register_grant( + AuthorizationCodeGrant, extensions=[MyOpenIDCode()] + ) + """ + + def __init__(self, require_nonce=False): + self.require_nonce = require_nonce + + def exists_nonce(self, nonce, request): + """Check if the given nonce is existing in your database. Developers + MUST implement this method in subclass, e.g.:: + + def exists_nonce(self, nonce, request): + exists = AuthorizationCode.query.filter_by( + client_id=request.payload.client_id, nonce=nonce + ).first() + return bool(exists) + + :param nonce: A string of "nonce" parameter in request + :param request: OAuth2Request instance + :return: Boolean + """ + raise NotImplementedError() - def validate_openid_authorization_request(self, grant): + def validate_openid_authorization_request(self, grant, redirect_uri): validate_nonce(grant.request, self.exists_nonce, self.require_nonce) def __call__(self, grant): - grant.register_hook('process_token', self.process_token) - if is_openid_scope(grant.request.scope): + grant.register_hook("after_create_token_response", self.process_token) + if is_openid_scope(grant.request.payload.scope): grant.register_hook( - 'after_validate_authorization_request', - self.validate_openid_authorization_request + "after_validate_authorization_request_payload", + self.validate_openid_authorization_request, ) grant.register_hook( - 'after_validate_consent_request', - validate_request_prompt + "after_validate_consent_request", validate_request_prompt ) diff --git a/authlib/oidc/core/grants/hybrid.py b/authlib/oidc/core/grants/hybrid.py index d2c14acf6..8c373525b 100644 --- a/authlib/oidc/core/grants/hybrid.py +++ b/authlib/oidc/core/grants/hybrid.py @@ -1,12 +1,14 @@ import logging -from authlib.deprecate import deprecate + from authlib.common.security import generate_token from authlib.oauth2.rfc6749 import InvalidScopeError from authlib.oauth2.rfc6749.grants.authorization_code import ( - validate_code_authorization_request + validate_code_authorization_request, ) + from .implicit import OpenIDImplicitGrant -from .util import is_openid_scope, validate_nonce +from .util import is_openid_scope +from .util import validate_nonce log = logging.getLogger(__name__) @@ -15,12 +17,12 @@ class OpenIDHybridGrant(OpenIDImplicitGrant): #: Generated "code" length AUTHORIZATION_CODE_LENGTH = 48 - RESPONSE_TYPES = {'code id_token', 'code token', 'code id_token token'} - GRANT_TYPE = 'code' - DEFAULT_RESPONSE_MODE = 'fragment' + RESPONSE_TYPES = {"code id_token", "code token", "code id_token token"} + GRANT_TYPE = "code" + DEFAULT_RESPONSE_MODE = "fragment" def generate_authorization_code(self): - """"The method to generate "code" value for authorization code data. + """ "The method to generate "code" value for authorization code data. Developers may rewrite this method, or customize the code length with:: class MyAuthorizationCodeGrant(AuthorizationCodeGrant): @@ -34,63 +36,55 @@ def save_authorization_code(self, code, request): def save_authorization_code(self, code, request): client = request.client - item = AuthorizationCode( + auth_code = AuthorizationCode( code=code, client_id=client.client_id, - redirect_uri=request.redirect_uri, - scope=request.scope, - nonce=request.data.get('nonce'), + redirect_uri=request.payload.redirect_uri, + scope=request.payload.scope, + nonce=request.payload.data.get("nonce"), user_id=request.user.id, ) - item.save() + auth_code.save() """ raise NotImplementedError() def validate_authorization_request(self): - if not is_openid_scope(self.request.scope): + if not is_openid_scope(self.request.payload.scope): raise InvalidScopeError( - 'Missing "openid" scope', - redirect_uri=self.request.redirect_uri, + "Missing 'openid' scope", + redirect_uri=self.request.payload.redirect_uri, redirect_fragment=True, ) self.register_hook( - 'after_validate_authorization_request', - lambda grant: validate_nonce( - grant.request, grant.exists_nonce, required=True) + "after_validate_authorization_request_payload", + lambda grant, redirect_uri: validate_nonce( + grant.request, grant.exists_nonce, required=True + ), ) return validate_code_authorization_request(self) def create_granted_params(self, grant_user): self.request.user = grant_user client = self.request.client - - if hasattr(self, 'create_authorization_code'): # pragma: no cover - deprecate('Use "generate_authorization_code" instead', '1.0') - code = self.create_authorization_code(client, grant_user, self.request) - else: - code = self.generate_authorization_code() - self.save_authorization_code(code, self.request) - - params = [('code', code)] + code = self.generate_authorization_code() + self.save_authorization_code(code, self.request) + params = [("code", code)] token = self.generate_token( - grant_type='implicit', + grant_type="implicit", user=grant_user, - scope=self.request.scope, - include_refresh_token=False + scope=self.request.payload.scope, + include_refresh_token=False, ) - response_types = self.request.response_type.split() - if 'token' in response_types: - log.debug('Grant token %r to %r', token, client) + response_types = self.request.payload.response_type.split() + if "token" in response_types: + log.debug("Grant token %r to %r", token, client) self.server.save_token(token, self.request) - if 'id_token' in response_types: + if "id_token" in response_types: token = self.process_implicit_token(token, code) else: # response_type is "code id_token" - token = { - 'expires_in': token['expires_in'], - 'scope': token['scope'] - } + token = {"expires_in": token["expires_in"], "scope": token["scope"]} token = self.process_implicit_token(token, code) params.extend([(k, token[k]) for k in token]) diff --git a/authlib/oidc/core/grants/implicit.py b/authlib/oidc/core/grants/implicit.py index ac1a76313..c8a09bc3c 100644 --- a/authlib/oidc/core/grants/implicit.py +++ b/authlib/oidc/core/grants/implicit.py @@ -1,24 +1,29 @@ import logging -from authlib.oauth2.rfc6749 import ( - OAuth2Error, - InvalidScopeError, - AccessDeniedError, - ImplicitGrant, -) -from .util import ( - is_openid_scope, - validate_nonce, - validate_request_prompt, - create_response_mode_response, - generate_id_token, -) +import warnings + +from joserfc import jwt + +from authlib._joserfc_helpers import import_any_key +from authlib.oauth2.rfc6749 import AccessDeniedError +from authlib.oauth2.rfc6749 import ImplicitGrant +from authlib.oauth2.rfc6749 import InvalidScopeError +from authlib.oauth2.rfc6749 import OAuth2Error +from authlib.oauth2.rfc6749.errors import InvalidRequestError +from authlib.oauth2.rfc6749.hooks import hooked + +from ._legacy import LegacyMixin +from .util import create_half_hash +from .util import create_response_mode_response +from .util import is_openid_scope +from .util import validate_nonce +from .util import validate_request_prompt log = logging.getLogger(__name__) -class OpenIDImplicitGrant(ImplicitGrant): - RESPONSE_TYPES = {'id_token token', 'id_token'} - DEFAULT_RESPONSE_MODE = 'fragment' +class OpenIDImplicitGrant(LegacyMixin, ImplicitGrant): + RESPONSE_TYPES = {"id_token token", "id_token"} + DEFAULT_RESPONSE_MODE = "fragment" def exists_nonce(self, nonce, request): """Check if the given nonce is existing in your database. Developers @@ -26,7 +31,7 @@ def exists_nonce(self, nonce, request): def exists_nonce(self, nonce, request): exists = AuthorizationCode.query.filter_by( - client_id=request.client_id, nonce=nonce + client_id=request.payload.client_id, nonce=nonce ).first() return bool(exists) @@ -36,33 +41,17 @@ def exists_nonce(self, nonce, request): """ raise NotImplementedError() - def get_jwt_config(self): - """Get the JWT configuration for OpenIDImplicitGrant. The JWT - configuration will be used to generate ``id_token``. Developers - MUST implement this method in subclass, e.g.:: - - def get_jwt_config(self): - return { - 'key': read_private_key_file(key_path), - 'alg': 'RS512', - 'iss': 'issuer-identity', - 'exp': 3600 - } - - :return: dict - """ - raise NotImplementedError() - def generate_user_info(self, user, scope): """Provide user information for the given scope. Developers MUST implement this method in subclass, e.g.:: from authlib.oidc.core import UserInfo + def generate_user_info(self, user, scope): user_info = UserInfo(sub=user.id, name=user.name) - if 'email' in scope: - user_info['email'] = user.email + if "email" in scope: + user_info["email"] = user.email return user_info :param user: user instance @@ -79,14 +68,13 @@ def get_audiences(self, request): return [client.get_client_id()] def validate_authorization_request(self): - if not is_openid_scope(self.request.scope): + if not is_openid_scope(self.request.payload.scope): raise InvalidScopeError( - 'Missing "openid" scope', - redirect_uri=self.request.redirect_uri, + "Missing 'openid' scope", + redirect_uri=self.request.payload.redirect_uri, redirect_fragment=True, ) - redirect_uri = super( - OpenIDImplicitGrant, self).validate_authorization_request() + redirect_uri = super().validate_authorization_request() try: validate_nonce(self.request, self.exists_nonce, required=True) except OAuth2Error as error: @@ -95,22 +83,26 @@ def validate_authorization_request(self): raise error return redirect_uri + @hooked def validate_consent_request(self): redirect_uri = self.validate_authorization_request() validate_request_prompt(self, redirect_uri, redirect_fragment=True) + return redirect_uri def create_authorization_response(self, redirect_uri, grant_user): - state = self.request.state + state = self.request.payload.state if grant_user: params = self.create_granted_params(grant_user) if state: - params.append(('state', state)) + params.append(("state", state)) else: - error = AccessDeniedError(state=state) + error = AccessDeniedError() params = error.get_body() # http://openid.net/specs/oauth-v2-multiple-response-types-1_0.html#ResponseModes - response_mode = self.request.data.get('response_mode', self.DEFAULT_RESPONSE_MODE) + response_mode = self.request.payload.data.get( + "response_mode", self.DEFAULT_RESPONSE_MODE + ) return create_response_mode_response( redirect_uri=redirect_uri, params=params, @@ -122,30 +114,73 @@ def create_granted_params(self, grant_user): client = self.request.client token = self.generate_token( user=grant_user, - scope=self.request.scope, - include_refresh_token=False + scope=self.request.payload.scope, + include_refresh_token=False, ) - if self.request.response_type == 'id_token': + if self.request.payload.response_type == "id_token": token = { - 'expires_in': token['expires_in'], - 'scope': token['scope'], + "expires_in": token["expires_in"], + "scope": token["scope"], } token = self.process_implicit_token(token) else: - log.debug('Grant token %r to %r', token, client) + log.debug("Grant token %r to %r", token, client) self.server.save_token(token, self.request) token = self.process_implicit_token(token) params = [(k, token[k]) for k in token] return params def process_implicit_token(self, token, code=None): - config = self.get_jwt_config() - config['aud'] = self.get_audiences(self.request) - config['nonce'] = self.request.data.get('nonce') - if code is not None: - config['code'] = code + alg = self.get_client_algorithm(self.request.client) + if alg == "none": + # According to oidc-registration §2 the 'none' alg is not valid in + # implicit flows: + # The value none MUST NOT be used as the ID Token alg value unless + # the Client uses only Response Types that return no ID Token from + # the Authorization Endpoint (such as when only using the + # Authorization Code Flow). + raise InvalidRequestError( + "id_token must be signed in implicit flows", + redirect_uri=self.request.payload.redirect_uri, + redirect_fragment=True, + ) - user_info = self.generate_user_info(self.request.user, token['scope']) - id_token = generate_id_token(token, user_info, **config) - token['id_token'] = id_token + claims = self.get_compatible_claims(self.request) + nonce = self.request.payload.data.get("nonce") + if nonce: + claims["nonce"] = nonce + + if code is not None: + c_hash = create_half_hash(code, alg) + if c_hash is not None: + claims["c_hash"] = c_hash.decode("utf-8") + + access_token = token.get("access_token") + if access_token: + at_hash = create_half_hash(access_token, alg) + if at_hash is not None: + claims["at_hash"] = at_hash.decode("utf-8") + + user_info = self.generate_user_info(self.request.user, token["scope"]) + claims.update(user_info) + key = self.resolve_client_private_key(self.request.client) + private_key = import_any_key(key) + header = self.get_encode_header(self.request.client) + id_token = jwt.encode(header, claims, private_key, [alg]) + token["id_token"] = id_token return token + + def _compatible_resolve_jwt_config(self, grant, client): + if not hasattr(self, "get_jwt_config"): + return {} + warnings.warn( + "get_jwt_config(self, client) is deprecated and will be removed in version 1.8. " + "Use resolve_client_private_key, get_client_claims, get_client_algorithm instead.", + DeprecationWarning, + stacklevel=2, + ) + try: + config = self.get_jwt_config(client) + except TypeError: + config = self.get_jwt_config() + return config diff --git a/authlib/oidc/core/grants/util.py b/authlib/oidc/core/grants/util.py index a83a2c948..a228fdaab 100644 --- a/authlib/oidc/core/grants/util.py +++ b/authlib/oidc/core/grants/util.py @@ -1,155 +1,167 @@ import time -import random -from authlib.oauth2.rfc6749 import InvalidRequestError -from authlib.oauth2.rfc6749.util import scope_to_list -from authlib.jose import JWT + +from joserfc import jwt + +from authlib._joserfc_helpers import import_any_key from authlib.common.encoding import to_native -from authlib.common.urls import add_params_to_uri, quote_url +from authlib.common.urls import add_params_to_uri +from authlib.common.urls import quote_url +from authlib.oauth2.rfc6749 import InvalidRequestError +from authlib.oauth2.rfc6749 import scope_to_list + +from ..errors import AccountSelectionRequiredError +from ..errors import ConsentRequiredError +from ..errors import LoginRequiredError from ..util import create_half_hash -from ..errors import ( - LoginRequiredError, - AccountSelectionRequiredError, - ConsentRequiredError, -) def is_openid_scope(scope): scopes = scope_to_list(scope) - return scopes and 'openid' in scopes + return scopes and "openid" in scopes def validate_request_prompt(grant, redirect_uri, redirect_fragment=False): - prompt = grant.request.data.get('prompt') + prompt = grant.request.payload.data.get("prompt") end_user = grant.request.user if not prompt: if not end_user: - grant.prompt = 'login' + grant.prompt = "login" return grant - if prompt == 'none' and not end_user: + if prompt == "none" and not end_user: raise LoginRequiredError( - redirect_uri=redirect_uri, - redirect_fragment=redirect_fragment) + redirect_uri=redirect_uri, redirect_fragment=redirect_fragment + ) prompts = prompt.split() - if 'none' in prompts and len(prompts) > 1: + if "none" in prompts and len(prompts) > 1: # If this parameter contains none with any other value, # an error is returned raise InvalidRequestError( - 'Invalid "prompt" parameter.', + "Invalid 'prompt' parameter.", redirect_uri=redirect_uri, - redirect_fragment=redirect_fragment) + redirect_fragment=redirect_fragment, + ) prompt = _guess_prompt_value( - end_user, prompts, redirect_uri, redirect_fragment=redirect_fragment) + end_user, prompts, redirect_uri, redirect_fragment=redirect_fragment + ) if prompt: grant.prompt = prompt return grant def validate_nonce(request, exists_nonce, required=False): - nonce = request.data.get('nonce') + nonce = request.payload.data.get("nonce") if not nonce: if required: - raise InvalidRequestError('Missing "nonce" in request.') + raise InvalidRequestError("Missing 'nonce' in request.") return True if exists_nonce(nonce, request): - raise InvalidRequestError('Replay attack') + raise InvalidRequestError("Replay attack") def generate_id_token( - token, user_info, key, alg, iss, aud, exp, - nonce=None, auth_time=None, code=None): + token, + user_info, + key, + iss, + aud, + alg="RS256", + exp=3600, + nonce=None, + auth_time=None, + acr=None, + amr=None, + code=None, + kid=None, +): + now = int(time.time()) + if auth_time is None: + auth_time = now + + header = {"alg": alg} + if kid: + header["kid"] = kid + + payload = { + "iss": iss, + "aud": aud, + "iat": now, + "exp": now + exp, + "auth_time": auth_time, + } + if nonce: + payload["nonce"] = nonce + + if acr: + payload["acr"] = acr + + if amr: + payload["amr"] = amr + + if code: + c_hash = create_half_hash(code, alg) + if c_hash is not None: + payload["c_hash"] = to_native(c_hash) + + access_token = token.get("access_token") + if access_token: + at_hash = create_half_hash(access_token, alg) + if at_hash is not None: + payload["at_hash"] = to_native(at_hash) - payload = _generate_id_token_payload( - alg=alg, iss=iss, aud=aud, exp=exp, nonce=nonce, - auth_time=auth_time, code=code, - access_token=token.get('access_token'), - ) payload.update(user_info) - return _jwt_encode(alg, payload, key) + if alg == "none": + private_key = None + else: + private_key = import_any_key(key) + + return jwt.encode(header, payload, private_key, [alg]) def create_response_mode_response(redirect_uri, params, response_mode): - if response_mode == 'form_post': + if response_mode == "form_post": tpl = ( - 'Redirecting' + "Redirecting" '' '
{}
' ) - inputs = ''.join([ - ''.format( - quote_url(k), quote_url(v)) - for k, v in params - ]) + inputs = "".join( + [ + f'' + for k, v in params + ] + ) body = tpl.format(quote_url(redirect_uri), inputs) - return 200, body, [('Content-Type', 'text/html; charset=utf-8')] + return 200, body, [("Content-Type", "text/html; charset=utf-8")] - if response_mode == 'query': + if response_mode == "query": uri = add_params_to_uri(redirect_uri, params, fragment=False) - elif response_mode == 'fragment': + elif response_mode == "fragment": uri = add_params_to_uri(redirect_uri, params, fragment=True) else: raise InvalidRequestError('Invalid "response_mode" value') - return 302, '', [('Location', uri)] + return 302, "", [("Location", uri)] def _guess_prompt_value(end_user, prompts, redirect_uri, redirect_fragment): # http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest - if not end_user and 'login' in prompts: - return 'login' + if not end_user or "login" in prompts: + return "login" - if 'consent' in prompts: + if "consent" in prompts: if not end_user: raise ConsentRequiredError( - redirect_uri=redirect_uri, - redirect_fragment=redirect_fragment) - return 'consent' - elif 'select_account' in prompts: + redirect_uri=redirect_uri, redirect_fragment=redirect_fragment + ) + return "consent" + elif "select_account" in prompts: if not end_user: raise AccountSelectionRequiredError( - redirect_uri=redirect_uri, - redirect_fragment=redirect_fragment) - return 'select_account' - - -def _generate_id_token_payload( - alg, iss, aud, exp, nonce=None, auth_time=None, - code=None, access_token=None): - now = int(time.time()) - if auth_time is None: - auth_time = now - - payload = { - 'iss': iss, - 'aud': aud, - 'iat': now, - 'exp': now + exp, - 'auth_time': auth_time, - } - if nonce: - payload['nonce'] = nonce - - if code: - payload['c_hash'] = to_native(create_half_hash(code, alg)) - - if access_token: - payload['at_hash'] = to_native(create_half_hash(access_token, alg)) - return payload - - -def _jwt_encode(alg, payload, key): - jwt = JWT(algorithms=alg) - header = {'alg': alg} - if isinstance(key, dict): - # JWK set format - if 'keys' in key: - key = random.choice(key['keys']) - header['kid'] = key['kid'] - elif 'kid' in key: - header['kid'] = key['kid'] - - return to_native(jwt.encode(header, payload, key)) + redirect_uri=redirect_uri, redirect_fragment=redirect_fragment + ) + return "select_account" diff --git a/authlib/oidc/core/models.py b/authlib/oidc/core/models.py index 5f4140507..4350e9196 100644 --- a/authlib/oidc/core/models.py +++ b/authlib/oidc/core/models.py @@ -1,13 +1,29 @@ -from authlib.oauth2.rfc6749 import ( - AuthorizationCodeMixin as _AuthorizationCodeMixin -) +from authlib.oauth2.rfc6749 import AuthorizationCodeMixin as _AuthorizationCodeMixin class AuthorizationCodeMixin(_AuthorizationCodeMixin): def get_nonce(self): """Get "nonce" value of the authorization code object.""" + # OPs MUST support the prompt parameter, as defined in Section 3.1.2, including the specified user interface behaviors such as none and login. raise NotImplementedError() def get_auth_time(self): """Get "auth_time" value of the authorization code object.""" + # OPs MUST support returning the time at which the End-User authenticated via the auth_time Claim, when requested, as defined in Section 2. raise NotImplementedError() + + def get_acr(self) -> str: + """Get the "acr" (Authentication Method Class) value of the authorization code object.""" + # OPs MUST support requests for specific Authentication Context Class Reference values via the acr_values parameter, as defined in Section 3.1.2. (Note that the minimum level of support required for this parameter is simply to have its use not result in an error.) + return None + + def get_amr(self) -> list[str]: + """Get the "amr" (Authentication Method Reference) value of the authorization code object. + + Have a look at :rfc:`RFC8176 <8176>` to see the full list of registered amr. + + def get_amr(self) -> list[str]: + return ["pwd", "otp"] + + """ + return None diff --git a/authlib/oidc/core/userinfo.py b/authlib/oidc/core/userinfo.py new file mode 100644 index 000000000..8c0ab8c0f --- /dev/null +++ b/authlib/oidc/core/userinfo.py @@ -0,0 +1,133 @@ +from joserfc import jwt +from joserfc.jws import JWSRegistry + +from authlib._joserfc_helpers import import_any_key +from authlib.consts import default_json_headers +from authlib.oauth2.rfc6749.authorization_server import AuthorizationServer +from authlib.oauth2.rfc6749.authorization_server import OAuth2Request +from authlib.oauth2.rfc6749.resource_protector import ResourceProtector + +from .claims import UserInfo + + +class UserInfoEndpoint: + """OpenID Connect Core UserInfo Endpoint. + + This endpoint returns information about a given user, as a JSON payload or as a JWT. + It must be subclassed and a few methods needs to be manually implemented:: + + class UserInfoEndpoint(oidc.core.UserInfoEndpoint): + def get_issuer(self): + return "https://auth.example" + + def generate_user_info(self, user, scope): + return UserInfo( + sub=user.id, + name=user.name, + ... + ).filter(scope) + + def resolve_private_key(self): + return server_private_jwk_set() + + It is also needed to pass a :class:`~authlib.oauth2.rfc6749.ResourceProtector` instance + with a registered :class:`~authlib.oauth2.rfc6749.TokenValidator` at initialization, + so the access to the endpoint can be restricter to valid token bearers:: + + resource_protector = ResourceProtector() + resource_protector.register_token_validator(BearerTokenValidator()) + server.register_endpoint( + UserInfoEndpoint(resource_protector=resource_protector) + ) + + And then you can plug the endpoint to your application:: + + @app.route("/oauth/userinfo", methods=["GET", "POST"]) + def userinfo(): + return server.create_endpoint_response("userinfo") + + """ + + ENDPOINT_NAME = "userinfo" + + def __init__( + self, + server: AuthorizationServer | None = None, + resource_protector: ResourceProtector | None = None, + ): + self.server = server + self.resource_protector = resource_protector + + def create_endpoint_request(self, request: OAuth2Request): + return self.server.create_oauth2_request(request) + + def __call__(self, request: OAuth2Request): + token = self.resource_protector.acquire_token("openid") + client = token.get_client() + user = token.get_user() + user_info = self.generate_user_info(user, token.scope) + + if alg := client.client_metadata.get("userinfo_signed_response_alg"): + # If signed, the UserInfo Response MUST contain the Claims iss + # (issuer) and aud (audience) as members. The iss value MUST be + # the OP's Issuer Identifier URL. The aud value MUST be or + # include the RP's Client ID value. + user_info["iss"] = self.get_issuer() + user_info["aud"] = client.client_id + + key = import_any_key(self.resolve_private_key()) + algorithms = self.get_supported_algorithms() + data = jwt.encode({"alg": alg}, user_info, key, algorithms) + return 200, data, [("Content-Type", "application/jwt")] + + return 200, user_info, default_json_headers + + def get_supported_algorithms(self) -> list[str]: + """Return the supported algorithms for userinfo signing. + By default, it uses the recommended algorithms from ``joserfc``. + Developer can override this method to customize the supported algorithms:: + + def get_supported_algorithms(self) -> list[str]: + return ["RS256"] + """ + return JWSRegistry.recommended + + def generate_user_info(self, user, scope: str) -> UserInfo: + """ + Generate a :class:`~authlib.oidc.core.UserInfo` object for an user:: + + def generate_user_info(self, user, scope: str) -> UserInfo: + return UserInfo( + given_name=user.given_name, + family_name=user.last_name, + email=user.email, + ... + ).filter(scope) + + This method must be implemented by developers. + """ + raise NotImplementedError() + + def get_issuer(self) -> str: + """The OP's Issuer Identifier URL. + + The value is used to fill the ``iss`` claim that is mandatory in signed userinfo:: + + def get_issuer(self) -> str: + return "https://auth.example" + + This method must be implemented by developers to support JWT userinfo. + """ + raise NotImplementedError() + + def resolve_private_key(self): + """Return the server JSON Web Key Set. + + This is used to sign userinfo payloads:: + + def resolve_private_key(self): + return server_private_jwk_set() + + This method must be implemented by developers to support JWT userinfo signing. + """ + return None # pragma: no cover diff --git a/authlib/oidc/core/util.py b/authlib/oidc/core/util.py index 37d23dedb..9463f95f2 100644 --- a/authlib/oidc/core/util.py +++ b/authlib/oidc/core/util.py @@ -1,12 +1,18 @@ import hashlib -from authlib.common.encoding import to_bytes, urlsafe_b64encode + +from authlib.common.encoding import to_bytes +from authlib.common.encoding import urlsafe_b64encode def create_half_hash(s, alg): - hash_type = 'sha{}'.format(alg[2:]) - hash_alg = getattr(hashlib, hash_type, None) - if not hash_alg: - return None + if alg == "EdDSA": + hash_alg = hashlib.sha512 + else: + hash_type = f"sha{alg[2:]}" + hash_alg = getattr(hashlib, hash_type, None) + if not hash_alg: + return None + data_digest = hash_alg(to_bytes(s)).digest() slice_index = int(len(data_digest) / 2) return urlsafe_b64encode(data_digest[:slice_index]) diff --git a/authlib/oidc/discovery/__init__.py b/authlib/oidc/discovery/__init__.py index 1e76401bd..8c9822015 100644 --- a/authlib/oidc/discovery/__init__.py +++ b/authlib/oidc/discovery/__init__.py @@ -1,13 +1,12 @@ -""" - authlib.oidc.discover - ~~~~~~~~~~~~~~~~~~~~~ +"""authlib.oidc.discover. +~~~~~~~~~~~~~~~~~~~~~ - OpenID Connect Discovery 1.0 Implementation. +OpenID Connect Discovery 1.0 Implementation. - https://openid.net/specs/openid-connect-discovery-1_0.html +https://openid.net/specs/openid-connect-discovery-1_0.html """ from .models import OpenIDProviderMetadata from .well_known import get_well_known_url -__all__ = ['OpenIDProviderMetadata', 'get_well_known_url'] +__all__ = ["OpenIDProviderMetadata", "get_well_known_url"] diff --git a/authlib/oidc/discovery/models.py b/authlib/oidc/discovery/models.py index db1a8046d..00300b5a4 100644 --- a/authlib/oidc/discovery/models.py +++ b/authlib/oidc/discovery/models.py @@ -1,38 +1,57 @@ from authlib.oauth2.rfc8414 import AuthorizationServerMetadata from authlib.oauth2.rfc8414.models import validate_array_value +from authlib.oauth2.rfc8414.models import validate_boolean_value class OpenIDProviderMetadata(AuthorizationServerMetadata): - REGISTRY_KEYS = [ - 'issuer', 'authorization_endpoint', 'token_endpoint', - 'jwks_uri', 'registration_endpoint', 'scopes_supported', - 'response_types_supported', 'response_modes_supported', - 'grant_types_supported', - 'token_endpoint_auth_methods_supported', - 'token_endpoint_auth_signing_alg_values_supported', - 'service_documentation', 'ui_locales_supported', - 'op_policy_uri', 'op_tos_uri', + """OpenID Provider Metadata for OpenID Connect Discovery. - # added by OpenID - 'acr_values_supported', 'subject_types_supported', - 'id_token_signing_alg_values_supported', - 'id_token_encryption_alg_values_supported', - 'id_token_encryption_enc_values_supported', - 'userinfo_signing_alg_values_supported', - 'userinfo_encryption_alg_values_supported', - 'userinfo_encryption_enc_values_supported', - 'request_object_signing_alg_values_supported', - 'request_object_encryption_alg_values_supported', - 'request_object_encryption_enc_values_supported', - 'display_values_supported', - 'claim_types_supported', - 'claims_supported', - 'claims_locales_supported', - 'claims_parameter_supported', - 'request_parameter_supported', - 'request_uri_parameter_supported', - 'require_request_uri_registration', + The :meth:`validate` method can compose extension classes via the + ``metadata_classes`` parameter. For example, to validate RP-Initiated + Logout metadata:: + + from authlib.oidc import discovery, rpinitiated + metadata = discovery.OpenIDProviderMetadata(data) + metadata.validate(metadata_classes=[rpinitiated.OpenIDProviderMetadata]) + """ + + REGISTRY_KEYS = [ + "issuer", + "authorization_endpoint", + "token_endpoint", + "jwks_uri", + "registration_endpoint", + "scopes_supported", + "response_types_supported", + "response_modes_supported", + "grant_types_supported", + "token_endpoint_auth_methods_supported", + "service_documentation", + "ui_locales_supported", + "op_policy_uri", + "op_tos_uri", + # added by OpenID + "token_endpoint_auth_signing_alg_values_supported", + "acr_values_supported", + "subject_types_supported", + "id_token_signing_alg_values_supported", + "id_token_encryption_alg_values_supported", + "id_token_encryption_enc_values_supported", + "userinfo_signing_alg_values_supported", + "userinfo_encryption_alg_values_supported", + "userinfo_encryption_enc_values_supported", + "request_object_signing_alg_values_supported", + "request_object_encryption_alg_values_supported", + "request_object_encryption_enc_values_supported", + "display_values_supported", + "claim_types_supported", + "claims_supported", + "claims_locales_supported", + "claims_parameter_supported", + "request_parameter_supported", + "request_uri_parameter_supported", + "require_request_uri_registration", # not defined by OpenID # 'revocation_endpoint', # 'revocation_endpoint_auth_methods_supported', @@ -45,23 +64,23 @@ class OpenIDProviderMetadata(AuthorizationServerMetadata): def validate_jwks_uri(self): # REQUIRED in OpenID Connect - jwks_uri = self.get('jwks_uri') + jwks_uri = self.get("jwks_uri") if jwks_uri is None: raise ValueError('"jwks_uri" is required') - return super(OpenIDProviderMetadata, self).validate_jwks_uri() + return super().validate_jwks_uri() def validate_acr_values_supported(self): """OPTIONAL. JSON array containing a list of the Authentication Context Class References that this OP supports. """ - validate_array_value(self, 'acr_values_supported') + validate_array_value(self, "acr_values_supported") def validate_subject_types_supported(self): """REQUIRED. JSON array containing a list of the Subject Identifier types that this OP supports. Valid types include pairwise and public. """ # 1. REQUIRED - values = self.get('subject_types_supported') + values = self.get("subject_types_supported") if values is None: raise ValueError('"subject_types_supported" is required') @@ -70,10 +89,9 @@ def validate_subject_types_supported(self): raise ValueError('"subject_types_supported" MUST be JSON array') # 3. Valid types include pairwise and public - valid_types = {'pairwise', 'public'} + valid_types = {"pairwise", "public"} if not valid_types.issuperset(set(values)): - raise ValueError( - '"subject_types_supported" contains invalid values') + raise ValueError('"subject_types_supported" contains invalid values') def validate_id_token_signing_alg_values_supported(self): """REQUIRED. JSON array containing a list of the JWS signing @@ -85,53 +103,56 @@ def validate_id_token_signing_alg_values_supported(self): Code Flow). """ # 1. REQUIRED - values = self.get('id_token_signing_alg_values_supported') + values = self.get("id_token_signing_alg_values_supported") if values is None: raise ValueError('"id_token_signing_alg_values_supported" is required') # 2. JSON array if not isinstance(values, list): - raise ValueError('"id_token_signing_alg_values_supported" MUST be JSON array') + raise ValueError( + '"id_token_signing_alg_values_supported" MUST be JSON array' + ) # 3. The algorithm RS256 MUST be included - if 'RS256' not in values: + if "RS256" not in values: raise ValueError( - '"RS256" MUST be included in "id_token_signing_alg_values_supported"') + '"RS256" MUST be included in "id_token_signing_alg_values_supported"' + ) def validate_id_token_encryption_alg_values_supported(self): """OPTIONAL. JSON array containing a list of the JWE encryption algorithms (alg values) supported by the OP for the ID Token to encode the Claims in a JWT. """ - validate_array_value(self, 'id_token_encryption_alg_values_supported') + validate_array_value(self, "id_token_encryption_alg_values_supported") def validate_id_token_encryption_enc_values_supported(self): """OPTIONAL. JSON array containing a list of the JWE encryption algorithms (enc values) supported by the OP for the ID Token to encode the Claims in a JWT. """ - validate_array_value(self, 'id_token_encryption_enc_values_supported') + validate_array_value(self, "id_token_encryption_enc_values_supported") def validate_userinfo_signing_alg_values_supported(self): """OPTIONAL. JSON array containing a list of the JWS signing algorithms (alg values) [JWA] supported by the UserInfo Endpoint to encode the Claims in a JWT. The value none MAY be included. """ - validate_array_value(self, 'userinfo_signing_alg_values_supported') + validate_array_value(self, "userinfo_signing_alg_values_supported") def validate_userinfo_encryption_alg_values_supported(self): """OPTIONAL. JSON array containing a list of the JWE encryption algorithms (alg values) [JWA] supported by the UserInfo Endpoint to encode the Claims in a JWT. """ - validate_array_value(self, 'userinfo_encryption_alg_values_supported') + validate_array_value(self, "userinfo_encryption_alg_values_supported") def validate_userinfo_encryption_enc_values_supported(self): """OPTIONAL. JSON array containing a list of the JWE encryption algorithms (enc values) [JWA] supported by the UserInfo Endpoint to encode the Claims in a JWT. """ - validate_array_value(self, 'userinfo_encryption_enc_values_supported') + validate_array_value(self, "userinfo_encryption_enc_values_supported") def validate_request_object_signing_alg_values_supported(self): """OPTIONAL. JSON array containing a list of the JWS signing @@ -142,18 +163,14 @@ def validate_request_object_signing_alg_values_supported(self): reference (using the request_uri parameter). Servers SHOULD support none and RS256. """ - values = self.get('request_object_signing_alg_values_supported') + values = self.get("request_object_signing_alg_values_supported") if not values: return if not isinstance(values, list): - raise ValueError('"request_object_signing_alg_values_supported" MUST be JSON array') - - # Servers SHOULD support none and RS256 - if 'none' not in values or 'RS256' not in values: raise ValueError( - '"request_object_signing_alg_values_supported" ' - 'SHOULD support none and RS256') + '"request_object_signing_alg_values_supported" MUST be JSON array' + ) def validate_request_object_encryption_alg_values_supported(self): """OPTIONAL. JSON array containing a list of the JWE encryption @@ -161,7 +178,7 @@ def validate_request_object_encryption_alg_values_supported(self): These algorithms are used both when the Request Object is passed by value and when it is passed by reference. """ - validate_array_value(self, 'request_object_encryption_alg_values_supported') + validate_array_value(self, "request_object_encryption_alg_values_supported") def validate_request_object_encryption_enc_values_supported(self): """OPTIONAL. JSON array containing a list of the JWE encryption @@ -169,21 +186,21 @@ def validate_request_object_encryption_enc_values_supported(self): These algorithms are used both when the Request Object is passed by value and when it is passed by reference. """ - validate_array_value(self, 'request_object_encryption_enc_values_supported') + validate_array_value(self, "request_object_encryption_enc_values_supported") def validate_display_values_supported(self): """OPTIONAL. JSON array containing a list of the display parameter values that the OpenID Provider supports. These values are described in Section 3.1.2.1 of OpenID Connect Core 1.0. """ - values = self.get('display_values_supported') + values = self.get("display_values_supported") if not values: return if not isinstance(values, list): raise ValueError('"display_values_supported" MUST be JSON array') - valid_values = {'page', 'popup', 'touch', 'wap'} + valid_values = {"page", "popup", "touch", "wap"} if not valid_values.issuperset(set(values)): raise ValueError('"display_values_supported" contains invalid values') @@ -194,14 +211,14 @@ def validate_claim_types_supported(self): specification are normal, aggregated, and distributed. If omitted, the implementation supports only normal Claims. """ - values = self.get('claim_types_supported') + values = self.get("claim_types_supported") if not values: return if not isinstance(values, list): raise ValueError('"claim_types_supported" MUST be JSON array') - valid_values = {'normal', 'aggregated', 'distributed'} + valid_values = {"normal", "aggregated", "distributed"} if not valid_values.issuperset(set(values)): raise ValueError('"claim_types_supported" contains invalid values') @@ -211,7 +228,7 @@ def validate_claims_supported(self): for. Note that for privacy or other reasons, this might not be an exhaustive list. """ - validate_array_value(self, 'claims_supported') + validate_array_value(self, "claims_supported") def validate_claims_locales_supported(self): """OPTIONAL. Languages and scripts supported for values in Claims @@ -219,28 +236,28 @@ def validate_claims_locales_supported(self): language tag values. Not all languages and scripts are necessarily supported for all Claim values. """ - validate_array_value(self, 'claims_locales_supported') + validate_array_value(self, "claims_locales_supported") def validate_claims_parameter_supported(self): """OPTIONAL. Boolean value specifying whether the OP supports use of the claims parameter, with true indicating support. If omitted, the default value is false. """ - _validate_boolean_value(self, 'claims_parameter_supported') + validate_boolean_value(self, "claims_parameter_supported") def validate_request_parameter_supported(self): """OPTIONAL. Boolean value specifying whether the OP supports use of the request parameter, with true indicating support. If omitted, the default value is false. """ - _validate_boolean_value(self, 'request_parameter_supported') + validate_boolean_value(self, "request_parameter_supported") def validate_request_uri_parameter_supported(self): """OPTIONAL. Boolean value specifying whether the OP supports use of the request_uri parameter, with true indicating support. If omitted, the default value is true. """ - _validate_boolean_value(self, 'request_uri_parameter_supported') + validate_boolean_value(self, "request_uri_parameter_supported") def validate_require_request_uri_registration(self): """OPTIONAL. Boolean value specifying whether the OP requires any @@ -248,36 +265,29 @@ def validate_require_request_uri_registration(self): registration parameter. Pre-registration is REQUIRED when the value is true. If omitted, the default value is false. """ - _validate_boolean_value(self, 'require_request_uri_registration') + validate_boolean_value(self, "require_request_uri_registration") @property def claim_types_supported(self): # If omitted, the implementation supports only normal Claims - return self.get('claim_types_supported', ['normal']) + return self.get("claim_types_supported", ["normal"]) @property def claims_parameter_supported(self): # If omitted, the default value is false. - return self.get('claims_parameter_supported', False) + return self.get("claims_parameter_supported", False) @property def request_parameter_supported(self): # If omitted, the default value is false. - return self.get('request_parameter_supported', False) + return self.get("request_parameter_supported", False) @property def request_uri_parameter_supported(self): # If omitted, the default value is true. - return self.get('request_uri_parameter_supported', True) + return self.get("request_uri_parameter_supported", True) @property def require_request_uri_registration(self): # If omitted, the default value is false. - return self.get('require_request_uri_registration', False) - - -def _validate_boolean_value(metadata, key): - if key not in metadata: - return - if metadata[key] not in (True, False): - raise ValueError('"{}" MUST be boolean'.format(key)) + return self.get("require_request_uri_registration", False) diff --git a/authlib/oidc/discovery/well_known.py b/authlib/oidc/discovery/well_known.py index e3087a143..0222962d0 100644 --- a/authlib/oidc/discovery/well_known.py +++ b/authlib/oidc/discovery/well_known.py @@ -10,8 +10,8 @@ def get_well_known_url(issuer, external=False): """ # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationRequest if external: - return issuer.rstrip('/') + '/.well-known/openid-configuration' + return issuer.rstrip("/") + "/.well-known/openid-configuration" parsed = urlparse.urlparse(issuer) path = parsed.path - return path.rstrip('/') + '/.well-known/openid-configuration' + return path.rstrip("/") + "/.well-known/openid-configuration" diff --git a/authlib/oidc/registration/__init__.py b/authlib/oidc/registration/__init__.py new file mode 100644 index 000000000..08cbf656c --- /dev/null +++ b/authlib/oidc/registration/__init__.py @@ -0,0 +1,3 @@ +from .claims import ClientMetadataClaims + +__all__ = ["ClientMetadataClaims"] diff --git a/authlib/oidc/registration/claims.py b/authlib/oidc/registration/claims.py new file mode 100644 index 000000000..a6fc2d071 --- /dev/null +++ b/authlib/oidc/registration/claims.py @@ -0,0 +1,334 @@ +from joserfc.errors import InvalidClaimError + +from authlib.common.urls import is_valid_url +from authlib.oauth2.claims import BaseClaims + + +class ClientMetadataClaims(BaseClaims): + REGISTERED_CLAIMS = [ + "token_endpoint_auth_signing_alg", + "application_type", + "sector_identifier_uri", + "subject_type", + "id_token_signed_response_alg", + "id_token_encrypted_response_alg", + "id_token_encrypted_response_enc", + "userinfo_signed_response_alg", + "userinfo_encrypted_response_alg", + "userinfo_encrypted_response_enc", + "default_max_age", + "require_auth_time", + "default_acr_values", + "initiate_login_uri", + "request_object_signing_alg", + "request_object_encryption_alg", + "request_object_encryption_enc", + "request_uris", + ] + + def validate(self, now=None, leeway=0): + super().validate(now, leeway) + self.validate_token_endpoint_auth_signing_alg() + self.validate_application_type() + self.validate_sector_identifier_uri() + self.validate_subject_type() + self.validate_id_token_signed_response_alg() + self.validate_id_token_encrypted_response_alg() + self.validate_id_token_encrypted_response_enc() + self.validate_userinfo_signed_response_alg() + self.validate_userinfo_encrypted_response_alg() + self.validate_userinfo_encrypted_response_enc() + self.validate_default_max_age() + self.validate_require_auth_time() + self.validate_default_acr_values() + self.validate_initiate_login_uri() + self.validate_request_object_signing_alg() + self.validate_request_object_encryption_alg() + self.validate_request_object_encryption_enc() + self.validate_request_uris() + + def _validate_uri(self, key): + uri = self.get(key) + uris = uri if isinstance(uri, list) else [uri] + for uri in uris: + if uri and not is_valid_url(uri): + raise InvalidClaimError(key) + + @classmethod + def get_claims_options(self, metadata): + """Generate claims options validation from Authorization Server metadata.""" + options = {} + + if acr_values_supported := metadata.get("acr_values_supported"): + + def _validate_default_acr_values(claims, value): + return not value or set(value).issubset(set(acr_values_supported)) + + options["default_acr_values"] = {"validate": _validate_default_acr_values} + + values_mapping = { + "token_endpoint_auth_signing_alg_values_supported": "token_endpoint_auth_signing_alg", + "subject_types_supported": "subject_type", + "id_token_signing_alg_values_supported": "id_token_signed_response_alg", + "id_token_encryption_alg_values_supported": "id_token_encrypted_response_alg", + "id_token_encryption_enc_values_supported": "id_token_encrypted_response_enc", + "userinfo_signing_alg_values_supported": "userinfo_signed_response_alg", + "userinfo_encryption_alg_values_supported": "userinfo_encrypted_response_alg", + "userinfo_encryption_enc_values_supported": "userinfo_encrypted_response_enc", + "request_object_signing_alg_values_supported": "request_object_signing_alg", + "request_object_encryption_alg_values_supported": "request_object_encryption_alg", + "request_object_encryption_enc_values_supported": "request_object_encryption_enc", + } + + def make_validator(metadata_claim_values): + def _validate(claims, value): + return not value or value in metadata_claim_values + + return _validate + + for metadata_claim_name, request_claim_name in values_mapping.items(): + if metadata_claim_values := metadata.get(metadata_claim_name): + options[request_claim_name] = { + "validate": make_validator(metadata_claim_values) + } + + return options + + def validate_token_endpoint_auth_signing_alg(self): + """JWS [JWS] alg algorithm [JWA] that MUST be used for signing the JWT [JWT] + used to authenticate the Client at the Token Endpoint for the private_key_jwt + and client_secret_jwt authentication methods. + + All Token Requests using these authentication methods from this Client MUST be + rejected, if the JWT is not signed with this algorithm. Servers SHOULD support + RS256. The value none MUST NOT be used. The default, if omitted, is that any + algorithm supported by the OP and the RP MAY be used. + """ + if self.get("token_endpoint_auth_signing_alg") == "none": + raise InvalidClaimError("token_endpoint_auth_signing_alg") + + def validate_application_type(self): + """Kind of the application. + + The default, if omitted, is web. The defined values are native or web. Web + Clients using the OAuth Implicit Grant Type MUST only register URLs using the + https scheme as redirect_uris; they MUST NOT use localhost as the hostname. + Native Clients MUST only register redirect_uris using custom URI schemes or + loopback URLs using the http scheme; loopback URLs use localhost or the IP + loopback literals 127.0.0.1 or [::1] as the hostname. Authorization Servers MAY + place additional constraints on Native Clients. Authorization Servers MAY + reject Redirection URI values using the http scheme, other than the loopback + case for Native Clients. The Authorization Server MUST verify that all the + registered redirect_uris conform to these constraints. This prevents sharing a + Client ID across different types of Clients. + """ + self.setdefault("application_type", "web") + if self.get("application_type") not in ("web", "native"): + raise InvalidClaimError("application_type") + + def validate_sector_identifier_uri(self): + """URL using the https scheme to be used in calculating Pseudonymous Identifiers + by the OP. + + The URL references a file with a single JSON array of redirect_uri values. + Please see Section 5. Providers that use pairwise sub (subject) values SHOULD + utilize the sector_identifier_uri value provided in the Subject Identifier + calculation for pairwise identifiers. + """ + self._validate_uri("sector_identifier_uri") + + def validate_subject_type(self): + """subject_type requested for responses to this Client. + + The subject_types_supported discovery parameter contains a list of the supported + subject_type values for the OP. Valid types include pairwise and public. + """ + + def validate_id_token_signed_response_alg(self): + """JWS alg algorithm [JWA] REQUIRED for signing the ID Token issued to this + Client. + + The value none MUST NOT be used as the ID Token alg value unless the Client uses + only Response Types that return no ID Token from the Authorization Endpoint + (such as when only using the Authorization Code Flow). The default, if omitted, + is RS256. The public key for validating the signature is provided by retrieving + the JWK Set referenced by the jwks_uri element from OpenID Connect Discovery 1.0 + [OpenID.Discovery]. + """ + if self.get( + "id_token_signed_response_alg" + ) == "none" and "id_token" in self.get("response_type", ""): + raise InvalidClaimError("id_token_signed_response_alg") + + self.setdefault("id_token_signed_response_alg", "RS256") + + def validate_id_token_encrypted_response_alg(self): + """JWE alg algorithm [JWA] REQUIRED for encrypting the ID Token issued to this + Client. + + If this is requested, the response will be signed then encrypted, with the + result being a Nested JWT, as defined in [JWT]. The default, if omitted, is that + no encryption is performed. + """ + + def validate_id_token_encrypted_response_enc(self): + """JWE enc algorithm [JWA] REQUIRED for encrypting the ID Token issued to this + Client. + + If id_token_encrypted_response_alg is specified, the default + id_token_encrypted_response_enc value is A128CBC-HS256. When + id_token_encrypted_response_enc is included, id_token_encrypted_response_alg + MUST also be provided. + """ + if self.get("id_token_encrypted_response_enc") and not self.get( + "id_token_encrypted_response_alg" + ): + raise InvalidClaimError("id_token_encrypted_response_enc") + + if self.get("id_token_encrypted_response_alg"): + self.setdefault("id_token_encrypted_response_enc", "A128CBC-HS256") + + def validate_userinfo_signed_response_alg(self): + """JWS alg algorithm [JWA] REQUIRED for signing UserInfo Responses. + + If this is specified, the response will be JWT [JWT] serialized, and signed + using JWS. The default, if omitted, is for the UserInfo Response to return the + Claims as a UTF-8 [RFC3629] encoded JSON object using the application/json + content-type. + """ + + def validate_userinfo_encrypted_response_alg(self): + """JWE [JWE] alg algorithm [JWA] REQUIRED for encrypting UserInfo Responses. + + If both signing and encryption are requested, the response will be signed then + encrypted, with the result being a Nested JWT, as defined in [JWT]. The default, + if omitted, is that no encryption is performed. + """ + + def validate_userinfo_encrypted_response_enc(self): + """JWE enc algorithm [JWA] REQUIRED for encrypting UserInfo Responses. + + If userinfo_encrypted_response_alg is specified, the default + userinfo_encrypted_response_enc value is A128CBC-HS256. When + userinfo_encrypted_response_enc is included, userinfo_encrypted_response_alg + MUST also be provided. + """ + if self.get("userinfo_encrypted_response_enc") and not self.get( + "userinfo_encrypted_response_alg" + ): + raise InvalidClaimError("userinfo_encrypted_response_enc") + + if self.get("userinfo_encrypted_response_alg"): + self.setdefault("userinfo_encrypted_response_enc", "A128CBC-HS256") + + def validate_default_max_age(self): + """Default Maximum Authentication Age. + + Specifies that the End-User MUST be actively authenticated if the End-User was + authenticated longer ago than the specified number of seconds. The max_age + request parameter overrides this default value. If omitted, no default Maximum + Authentication Age is specified. + """ + if self.get("default_max_age") is not None and not isinstance( + self["default_max_age"], (int, float) + ): + raise InvalidClaimError("default_max_age") + + def validate_require_auth_time(self): + """Boolean value specifying whether the auth_time Claim in the ID Token is + REQUIRED. + + It is REQUIRED when the value is true. (If this is false, the auth_time Claim + can still be dynamically requested as an individual Claim for the ID Token using + the claims request parameter described in Section 5.5.1 of OpenID Connect Core + 1.0 [OpenID.Core].) If omitted, the default value is false. + """ + self.setdefault("require_auth_time", False) + if self.get("require_auth_time") is not None and not isinstance( + self["require_auth_time"], bool + ): + raise InvalidClaimError("require_auth_time") + + def validate_default_acr_values(self): + """Default requested Authentication Context Class Reference values. + + Array of strings that specifies the default acr values that the OP is being + requested to use for processing requests from this Client, with the values + appearing in order of preference. The Authentication Context Class satisfied by + the authentication performed is returned as the acr Claim Value in the issued ID + Token. The acr Claim is requested as a Voluntary Claim by this parameter. The + acr_values_supported discovery element contains a list of the supported acr + values supported by the OP. Values specified in the acr_values request parameter + or an individual acr Claim request override these default values. + """ + + def validate_initiate_login_uri(self): + """RI using the https scheme that a third party can use to initiate a login by + the RP, as specified in Section 4 of OpenID Connect Core 1.0 [OpenID.Core]. + + The URI MUST accept requests via both GET and POST. The Client MUST understand + the login_hint and iss parameters and SHOULD support the target_link_uri + parameter. + """ + self._validate_uri("initiate_login_uri") + + def validate_request_object_signing_alg(self): + """JWS [JWS] alg algorithm [JWA] that MUST be used for signing Request Objects + sent to the OP. + + All Request Objects from this Client MUST be rejected, if not signed with this + algorithm. Request Objects are described in Section 6.1 of OpenID Connect Core + 1.0 [OpenID.Core]. This algorithm MUST be used both when the Request Object is + passed by value (using the request parameter) and when it is passed by reference + (using the request_uri parameter). Servers SHOULD support RS256. The value none + MAY be used. The default, if omitted, is that any algorithm supported by the OP + and the RP MAY be used. + """ + + def validate_request_object_encryption_alg(self): + """JWE [JWE] alg algorithm [JWA] the RP is declaring that it may use for + encrypting Request Objects sent to the OP. + + This parameter SHOULD be included when symmetric encryption will be used, since + this signals to the OP that a client_secret value needs to be returned from + which the symmetric key will be derived, that might not otherwise be returned. + The RP MAY still use other supported encryption algorithms or send unencrypted + Request Objects, even when this parameter is present. If both signing and + encryption are requested, the Request Object will be signed then encrypted, with + the result being a Nested JWT, as defined in [JWT]. The default, if omitted, is + that the RP is not declaring whether it might encrypt any Request Objects. + """ + + def validate_request_object_encryption_enc(self): + """JWE enc algorithm [JWA] the RP is declaring that it may use for encrypting + Request Objects sent to the OP. + + If request_object_encryption_alg is specified, the default + request_object_encryption_enc value is A128CBC-HS256. When + request_object_encryption_enc is included, request_object_encryption_alg MUST + also be provided. + """ + if self.get("request_object_encryption_enc") and not self.get( + "request_object_encryption_alg" + ): + raise InvalidClaimError("request_object_encryption_enc") + + if self.get("request_object_encryption_alg"): + self.setdefault("request_object_encryption_enc", "A128CBC-HS256") + + def validate_request_uris(self): + """Array of request_uri values that are pre-registered by the RP for use at the + OP. + + These URLs MUST use the https scheme unless the target Request Object is signed + in a way that is verifiable by the OP. Servers MAY cache the contents of the + files referenced by these URIs and not retrieve them at the time they are used + in a request. OPs can require that request_uri values used be pre-registered + with the require_request_uri_registration discovery parameter. If the contents + of the request file could ever change, these URI values SHOULD include the + base64url-encoded SHA-256 hash value of the file contents referenced by the URI + as the value of the URI fragment. If the fragment value used for a URI changes, + that signals the server that its cached value for that URI with the old fragment + value is no longer valid. + """ + self._validate_uri("request_uris") diff --git a/authlib/oidc/rpinitiated/__init__.py b/authlib/oidc/rpinitiated/__init__.py new file mode 100644 index 000000000..646c6969b --- /dev/null +++ b/authlib/oidc/rpinitiated/__init__.py @@ -0,0 +1,19 @@ +"""authlib.oidc.rpinitiated. +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +OpenID Connect RP-Initiated Logout 1.0 Implementation. + +https://openid.net/specs/openid-connect-rpinitiated-1_0.html +""" + +from .discovery import OpenIDProviderMetadata +from .end_session import EndSessionEndpoint +from .end_session import EndSessionRequest +from .registration import ClientMetadataClaims + +__all__ = [ + "EndSessionEndpoint", + "EndSessionRequest", + "ClientMetadataClaims", + "OpenIDProviderMetadata", +] diff --git a/authlib/oidc/rpinitiated/discovery.py b/authlib/oidc/rpinitiated/discovery.py new file mode 100644 index 000000000..c36e46a6d --- /dev/null +++ b/authlib/oidc/rpinitiated/discovery.py @@ -0,0 +1,13 @@ +from authlib.common.security import is_secure_transport + + +class OpenIDProviderMetadata(dict): + REGISTRY_KEYS = ["end_session_endpoint"] + + def validate_end_session_endpoint(self): + # rpinitiated §2.1: "end_session_endpoint - URL at the OP to which an + # RP can perform a redirect to request that the End-User be logged out + # at the OP. This URL MUST use the https scheme." + url = self.get("end_session_endpoint") + if url and not is_secure_transport(url): + raise ValueError('"end_session_endpoint" MUST use "https" scheme') diff --git a/authlib/oidc/rpinitiated/end_session.py b/authlib/oidc/rpinitiated/end_session.py new file mode 100644 index 000000000..bdd45899c --- /dev/null +++ b/authlib/oidc/rpinitiated/end_session.py @@ -0,0 +1,285 @@ +"""OpenID Connect RP-Initiated Logout 1.0 implementation. + +https://openid.net/specs/openid-connect-rpinitiated-1_0.html +""" + +from __future__ import annotations + +from dataclasses import dataclass +from dataclasses import field +from typing import TYPE_CHECKING +from typing import Any + +from joserfc import jwt +from joserfc.errors import JoseError +from joserfc.jwk import KeySet +from joserfc.jws import JWSRegistry + +from authlib.common.urls import add_params_to_uri +from authlib.oauth2.rfc6749.endpoint import Endpoint +from authlib.oauth2.rfc6749.endpoint import EndpointRequest +from authlib.oauth2.rfc6749.errors import InvalidRequestError + +if TYPE_CHECKING: + from authlib.oauth2.rfc6749.requests import OAuth2Request + + +class _NonExpiringClaimsRegistry(jwt.JWTClaimsRegistry): + """Claims registry that skips expiration validation.""" + + # rpinitiated §2: "The OP SHOULD accept ID Tokens when the RP identified by the + # ID Token's aud claim and/or sid claim has a current session or had a + # recent session at the OP, even when the exp time has passed." + def validate_exp(self, value: int) -> None: + pass + + +@dataclass +class EndSessionRequest(EndpointRequest): + """Validated end session request data. + + This object is returned by :meth:`EndSessionEndpoint.validate_request` + and contains all the validated information from the logout request. + """ + + id_token_claims: dict | None = field(default=None, repr=False) + redirect_uri: str | None = None + logout_hint: str | None = None + ui_locales: str | None = None + + @property + def needs_confirmation(self) -> bool: + """Whether user confirmation is recommended before logout.""" + + # rpinitiated §6: "Logout requests without a valid id_token_hint value are a + # potential means of denial of service; therefore, OPs should obtain + # explicit confirmation from the End-User before acting upon them." + return self.id_token_claims is None + + +class EndSessionEndpoint(Endpoint): + """OpenID Connect RP-Initiated Logout endpoint. + + This endpoint follows a two-phase pattern for interactive flows: + + 1. Call ``server.validate_endpoint_request("end_session")`` to validate + the request and get an :class:`EndSessionRequest` + 2. Check ``end_session_request.needs_confirmation`` and show UI if needed + 3. Call ``server.create_endpoint_response("end_session", end_session_request)`` + to execute logout and create the response + + Example usage:: + + class MyEndSessionEndpoint(EndSessionEndpoint): + def get_server_jwks(self): + return load_jwks() + + def end_session(self, end_session_request): + session.clear() + + + server.register_endpoint(MyEndSessionEndpoint) + + + @app.route("/logout", methods=["GET", "POST"]) + def logout(): + try: + req = server.validate_endpoint_request("end_session") + except OAuth2Error as error: + return server.handle_error_response(None, error) + + if req.needs_confirmation and request.method == "GET": + return render_template("confirm_logout.html", client=req.client) + + return server.create_endpoint_response( + "end_session", req + ) or render_template("logged_out.html") + + For non-interactive usage (no confirmation page), use the standard pattern:: + + @app.route("/logout", methods=["GET", "POST"]) + def logout(): + return server.create_endpoint_response("end_session") or render_template( + "logged_out.html" + ) + """ + + ENDPOINT_NAME = "end_session" + + def validate_request(self, request: OAuth2Request) -> EndSessionRequest: + """Validate an end session request. + + :param request: The OAuth2Request to validate + :returns: EndSessionRequest with validated data + :raises InvalidRequestError: If validation fails + """ + data = request.payload.data + + id_token_hint = data.get("id_token_hint") + client_id = data.get("client_id") + post_logout_redirect_uri = data.get("post_logout_redirect_uri") + state = data.get("state") + logout_hint = data.get("logout_hint") + ui_locales = data.get("ui_locales") + + # rpinitiated §2: "When an id_token_hint parameter is present, the OP MUST + # validate that it was the issuer of the ID Token." + id_token_claims = None + if id_token_hint: + id_token_claims = self._validate_id_token_hint(id_token_hint) + + # Resolve client + client = None + if client_id: + client = self.server.query_client(client_id) + elif id_token_claims: + client = self.resolve_client_from_id_token_claims(id_token_claims) + + # rpinitiated §2: "When both client_id and id_token_hint are present, the OP + # MUST verify that the Client Identifier matches the one used as the + # audience of the ID Token." + if client_id and id_token_claims: + aud = id_token_claims.get("aud") + aud_list = [aud] if isinstance(aud, str) else (aud or []) + if client_id not in aud_list: + raise InvalidRequestError("'client_id' does not match 'aud' claim") + + # rpinitiated §3: "The OP MUST NOT perform post-logout redirection if + # the post_logout_redirect_uri value supplied does not exactly match + # one of the previously registered post_logout_redirect_uris values." + redirect_uri = None + if ( + post_logout_redirect_uri + and client + and self._is_valid_post_logout_redirect_uri( + client, post_logout_redirect_uri + ) + and ( + id_token_claims + or self.is_post_logout_redirect_uri_legitimate( + request, post_logout_redirect_uri, client, logout_hint + ) + ) + ): + redirect_uri = post_logout_redirect_uri + # rpinitiated §3: "If the post_logout_redirect_uri value is provided + # and the preceding conditions are met, the OP MUST include the + # state value if the RP's initial Logout Request included state." + if state: + redirect_uri = add_params_to_uri(redirect_uri, {"state": state}) + + return EndSessionRequest( + request=request, + client=client, + id_token_claims=id_token_claims, + redirect_uri=redirect_uri, + logout_hint=logout_hint, + ui_locales=ui_locales, + ) + + def create_response( + self, validated_request: EndSessionRequest + ) -> tuple[int, Any, list[tuple[str, str]]] | None: + """Create the end session HTTP response. + + Executes the logout via :meth:`end_session`, then returns a redirect + response if a valid redirect_uri is present, or None to let the + application provide its own response. + + :param validated_request: The validated EndSessionRequest + :returns: Tuple of (status_code, body, headers) for redirect, or None + """ + req: EndSessionRequest = validated_request # type: ignore[assignment] + self.end_session(req) + + if req.redirect_uri: + return 302, "", [("Location", req.redirect_uri)] + return None + + def _validate_id_token_hint(self, id_token_hint: str) -> dict: + """Validate that the OP was the issuer of the ID Token.""" + # rpinitiated §2: "When an id_token_hint parameter is present, the OP MUST + # validate that it was the issuer of the ID Token." + # This is done by verifying the signature against the server's JWKS. + jwks = self.get_server_jwks() + if isinstance(jwks, dict): + jwks = KeySet.import_key_set(jwks) + + # rpinitiated §4: "When the OP detects errors in the RP-Initiated + # Logout request, the OP MUST not perform post-logout redirection." + try: + token = jwt.decode(id_token_hint, jwks, algorithms=self.get_algorithms()) + claims_registry = _NonExpiringClaimsRegistry(nbf={"essential": False}) + claims_registry.validate(token.claims) + except JoseError as exc: + raise InvalidRequestError(exc.description) from exc + + return dict(token.claims) + + def resolve_client_from_id_token_claims(self, id_token_claims: dict) -> Any | None: + """Resolve client from id_token aud claim. + + When aud is a single string, resolves the client directly. + When aud is a list, returns None (ambiguous case). + Override for custom resolution logic. + """ + aud = id_token_claims.get("aud") + if isinstance(aud, str): + return self.server.query_client(aud) + return None + + def _is_valid_post_logout_redirect_uri( + self, client, post_logout_redirect_uri: str + ) -> bool: + """Check if post_logout_redirect_uri is registered for the client.""" + registered_uris = client.client_metadata.get("post_logout_redirect_uris", []) + return post_logout_redirect_uri in registered_uris + + def is_post_logout_redirect_uri_legitimate( + self, + request: OAuth2Request, + post_logout_redirect_uri: str, + client, + logout_hint: str | None, + ) -> bool: + """Confirm redirect_uri legitimacy when no id_token_hint is provided. + + Override if you have alternative confirmation mechanisms, e.g.:: + + def is_post_logout_redirect_uri_legitimate(self, ...): + return client and client.is_trusted + + By default returns False (no redirection without id_token_hint). + """ + # rpinitiated §3: "if it is not supplied with post_logout_redirect_uri, + # the OP MUST NOT perform post-logout redirection unless the OP has + # other means of confirming the legitimacy" + return False + + def get_server_jwks(self) -> dict | KeySet: + """Return the server's JSON Web Key Set for validating ID tokens.""" + raise NotImplementedError() + + def get_algorithms(self) -> list[str]: + """Return the list of allowed algorithms for ID token validation. + + By default, returns all algorithms compatible with the keys in the JWKS. + Override to restrict to specific algorithms. + """ + jwks = self.get_server_jwks() + if isinstance(jwks, dict): + jwks = KeySet.import_key_set(jwks) + return [alg.name for alg in JWSRegistry.filter_algorithms(jwks)] + + def end_session(self, end_session_request: EndSessionRequest) -> None: + """Terminate the user's session. + + Implement this method to perform the actual logout logic, + such as clearing session data, revoking tokens, etc. + + Use ``end_session_request.logout_hint`` to help identify the user + (e.g. email, username) when no ``id_token_hint`` is provided. + + :param end_session_request: The validated EndSessionRequest + """ + raise NotImplementedError() diff --git a/authlib/oidc/rpinitiated/registration.py b/authlib/oidc/rpinitiated/registration.py new file mode 100644 index 000000000..3df8cf486 --- /dev/null +++ b/authlib/oidc/rpinitiated/registration.py @@ -0,0 +1,57 @@ +"""Client metadata for OpenID Connect RP-Initiated Logout 1.0. + +https://openid.net/specs/openid-connect-rpinitiated-1_0.html +""" + +from joserfc.errors import InvalidClaimError + +from authlib.common.security import is_secure_transport +from authlib.common.urls import is_valid_url +from authlib.oauth2.claims import BaseClaims + + +class ClientMetadataClaims(BaseClaims): + """Client metadata for OpenID Connect RP-Initiated Logout 1.0. + + This can be used with :ref:`specs/rfc7591` and :ref:`specs/rfc7592` endpoints:: + + server.register_endpoint( + ClientRegistrationEndpoint( + claims_classes=[ + rfc7591.ClientMetadataClaims, + oidc.registration.ClientMetadataClaims, + oidc.rpinitiated.ClientMetadataClaims, + ] + ) + ) + """ + + REGISTERED_CLAIMS = [ + "post_logout_redirect_uris", + ] + + def validate(self, now=None, leeway=0): + super().validate(now, leeway) + self._validate_post_logout_redirect_uris() + + def _validate_post_logout_redirect_uris(self): + # rpinitiated §3.1: "post_logout_redirect_uris - Array of URLs supplied + # by the RP to which it MAY request that the End-User's User Agent be + # redirected using the post_logout_redirect_uri parameter after a + # logout has been performed. These URLs SHOULD use the https scheme + # [...]; however, they MAY use the http scheme, provided that the + # Client Type is confidential." + uris = self.get("post_logout_redirect_uris") + if not uris: + return + + is_public = self.get("token_endpoint_auth_method") == "none" + + for uri in uris: + if not is_valid_url(uri): + raise InvalidClaimError("post_logout_redirect_uris") + + if is_public and not is_secure_transport(uri): + raise ValueError( + '"post_logout_redirect_uris" MUST use "https" scheme for public clients' + ) diff --git a/docs/_static/authlib.png b/docs/_static/authlib.png deleted file mode 100644 index c37c2a0a6..000000000 Binary files a/docs/_static/authlib.png and /dev/null differ diff --git a/docs/_static/authlib.svg b/docs/_static/authlib.svg deleted file mode 100644 index a8194bbee..000000000 --- a/docs/_static/authlib.svg +++ /dev/null @@ -1 +0,0 @@ -Authlib \ No newline at end of file diff --git a/docs/_static/custom.css b/docs/_static/custom.css new file mode 100644 index 000000000..dd1d35e22 --- /dev/null +++ b/docs/_static/custom.css @@ -0,0 +1,40 @@ +:root { + --syntax-light-pre-bg: #ecf5ff; + --syntax-light-cap-bg: #d6e7fb; + --syntax-dark-pre-bg: #1a2b3e; + --syntax-dark-cap-bg: #223e5e; +} + +#ethical-ad-placement { + display: none; +} + +.site-sponsors { + margin-bottom: 2rem; +} + +.site-sponsors > .sponsor { + display: flex; + align-items: center; + background: var(--sy-c-bg-weak); + border-radius: 6px; + padding: 0.5rem; + margin-bottom: 0.5rem; +} + +.site-sponsors .image { + flex-shrink: 0; + display: block; + width: 32px; + margin-right: 0.8rem; +} + +.site-sponsors .text { + font-size: 0.86rem; + line-height: 1.2; +} + +.site-sponsors .text a { + color: var(--sy-c-link); + border-color: var(--sy-c-link); +} diff --git a/docs/_static/dark-logo.svg b/docs/_static/dark-logo.svg new file mode 100644 index 000000000..5b1adfa82 --- /dev/null +++ b/docs/_static/dark-logo.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/favicon.ico b/docs/_static/favicon.ico deleted file mode 100644 index d275da7b6..000000000 Binary files a/docs/_static/favicon.ico and /dev/null differ diff --git a/docs/_static/icon.svg b/docs/_static/icon.svg new file mode 100644 index 000000000..974ed8fa9 --- /dev/null +++ b/docs/_static/icon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/light-logo.svg b/docs/_static/light-logo.svg new file mode 100644 index 000000000..f0cfb076f --- /dev/null +++ b/docs/_static/light-logo.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/sponsors.css b/docs/_static/sponsors.css deleted file mode 100644 index e70e7692c..000000000 --- a/docs/_static/sponsors.css +++ /dev/null @@ -1,77 +0,0 @@ -.ethical-fixedfooter { display:none } -.fund{ - z-index: 1; - position: relative; - bottom: 0; - right: 0; - float: right; - padding: 0 0 20px 30px; - width: 150px; -} -.fund a { border:0 } -#carbonads { - background: #EDF2F4; - padding: 5px 10px; - border-radius: 3px; -} -#carbonads span { - display: block; -} - -#carbonads a { - color: inherit; - text-decoration: none; -} - -#carbonads a:hover { - color: inherit; -} - -#carbonads span { - position: relative; - display: block; - overflow: hidden; -} - -.carbon-img img { - display: block; - width: 130px; -} - -#carbonads .carbon-text { - display: block; - margin-top: 4px; - font-size: 13px; - text-align: left; -} - -#carbonads .carbon-poweredby { - color: #aaa; - font-size: 10px; - letter-spacing: 0.8px; - text-transform: uppercase; - font-weight: normal; -} - -#bsa .native-box { - display: flex; - align-items: center; - padding: 10px; - margin: 16px 0; - border: 1px solid #e2e8f0; - border-radius: 4px; - background-color: #f8fafc; - text-decoration: none; - color: rgba(0, 0, 0, 0.68); -} - -#bsa .native-sponsor { - background-color: #447FD7; - color: #fff; - border-radius: 3px; - text-transform: uppercase; - padding: 5px 12px; - margin-right: 10px; - font-weight: 500; - font-size: 12px; -} diff --git a/docs/_static/sponsors.js b/docs/_static/sponsors.js deleted file mode 100644 index d6cd49f0e..000000000 --- a/docs/_static/sponsors.js +++ /dev/null @@ -1,42 +0,0 @@ -(function() { - function carbon() { - var h1 = document.querySelector('.t-body h1'); - if (!h1) return; - - var div = document.createElement('div'); - div.className = 'fund'; - h1.parentNode.insertBefore(div, h1.nextSibling); - - var s = document.createElement('script'); - s.async = 1; - s.id = '_carbonads_js'; - s.src = 'https://cdn.carbonads.com/carbon.js?serve=CE7DKK3W&placement=authliborg'; - div.appendChild(s); - } - - function bsa() { - var pagination = document.querySelector('.t-pagination'); - if (!pagination) return; - var div = document.createElement('div'); - div.id = 'bsa'; - pagination.parentNode.insertBefore(div, pagination); - - var s = document.createElement('script'); - s.async = 1; - s.src = 'https://m.servedby-buysellads.com/monetization.js'; - s.onload = function() { - if(typeof window._bsa !== 'undefined' && window._bsa) { - _bsa.init('custom', 'CE7DKK3M', 'placement:authliborg', { - target: '#bsa', - template: '
Sponsor
##company## - ##description##
' - }); - } - } - document.body.appendChild(s); - } - - document.addEventListener('DOMContentLoaded', function() { - carbon(); - setTimeout(bsa, 5000); - }); -})(); diff --git a/docs/_templates/funding.html b/docs/_templates/funding.html new file mode 100644 index 000000000..5de39b48b --- /dev/null +++ b/docs/_templates/funding.html @@ -0,0 +1,20 @@ + + diff --git a/docs/_templates/links.html b/docs/_templates/links.html index 92573c35b..e33566235 100644 --- a/docs/_templates/links.html +++ b/docs/_templates/links.html @@ -4,7 +4,7 @@

Useful Links

  • Homepage
  • Read Blog
  • Commercial License
  • -
  • Star on GitHub
  • +
  • Star on GitHub
  • Follow on Twitter
  • Help on StackOverflow
  • Loginpass
  • diff --git a/docs/_templates/partials/globaltoc-above.html b/docs/_templates/partials/globaltoc-above.html new file mode 100644 index 000000000..90143a770 --- /dev/null +++ b/docs/_templates/partials/globaltoc-above.html @@ -0,0 +1,7 @@ +
    + +
    +
    diff --git a/docs/_templates/sponsors.html b/docs/_templates/sponsors.html index fce1301a2..249fbf93a 100644 --- a/docs/_templates/sponsors.html +++ b/docs/_templates/sponsors.html @@ -8,6 +8,6 @@
    The new way to solve Identity. Sponsored by auth0.com
    diff --git a/docs/_templates/sustainable.html b/docs/_templates/sustainable.html deleted file mode 100644 index bb290215c..000000000 --- a/docs/_templates/sustainable.html +++ /dev/null @@ -1,13 +0,0 @@ -
    - - -
    diff --git a/docs/basic/install.rst b/docs/basic/install.rst index e65f0af77..543e83457 100644 --- a/docs/basic/install.rst +++ b/docs/basic/install.rst @@ -46,30 +46,24 @@ Using Authlib with Starlette:: $ pip install Authlib httpx Starlette -.. versionchanged:: v0.12 - - "requests" is an optional dependency since v0.12. If you want to use - Authlib client, you have to install "requests" by yourself:: - - $ pip install Authlib requests Get the Source Code ------------------- Authlib is actively developed on GitHub, where the code is -`always available `_. +`always available `_. You can either clone the public repository:: - $ git clone git://github.com/lepture/authlib.git + $ git clone git://github.com/authlib/authlib.git -Download the `tarball `_:: +Download the `tarball `_:: - $ curl -OL https://github.com/lepture/authlib/tarball/master + $ curl -OL https://github.com/authlib/authlib/tarball/main -Or, download the `zipball `_:: +Or, download the `zipball `_:: - $ curl -OL https://github.com/lepture/authlib/zipball/master + $ curl -OL https://github.com/authlib/authlib/zipball/main Once you have a copy of the source, you can embed it in your Python package, diff --git a/docs/basic/intro.rst b/docs/basic/intro.rst index de338515e..f8cb6c116 100644 --- a/docs/basic/intro.rst +++ b/docs/basic/intro.rst @@ -14,10 +14,8 @@ OAuth 1.0, OAuth 2.0, JWT and many more. It becomes a :ref:`monolithic` project that powers from low-level specification implementation to high-level framework integrations. -I'm intended to make it profitable so that it can be :ref:`sustainable`. - -.. raw:: html - :file: ../_templates/sustainable.html +I'm intended to make it profitable so that it can be :ref:`sustainable`, check +out the :ref:`funding` section. .. _monolithic: diff --git a/docs/changelog.rst b/docs/changelog.rst deleted file mode 100644 index be70ae3e9..000000000 --- a/docs/changelog.rst +++ /dev/null @@ -1,172 +0,0 @@ -Changelog -========= - -.. meta:: - :description: The full list of changes between each Authlib release. - -Here you can see the full list of changes between each Authlib release. - - -Version 0.14.3 --------------- - -**Released on May 18, 2020.** - -- Fix HTTPX integration via :gh:`PR#232` and :gh:`PR#233`. -- Add "bearer" as default token type for OAuth 2 Client. -- JWS and JWE don't validate private headers by default. -- Remove ``none`` auth method for authorization code by default. -- Allow usage of user provided ``code_verifier`` via :gh:`issue#216`. -- Add ``introspect_token`` method on OAuth 2 Client via :gh:`issue#224`. - - -Version 0.14.2 --------------- - -**Released on May 6, 2020.** - -- Fix OAuth 1.0 client for starlette. -- Allow leeway option in client parse ID token via :gh:`PR#228`. -- Fix OAuthToken when ``expires_at`` or ``expires_in`` is 0 via :gh:`PR#227`. -- Fix auto refresh token logic. -- Load server metadata before request. - - -Version 0.14.1 --------------- - -**Released on Feb 12, 2020.** - -- Quick fix for legacy imports of Flask and Django clients - - -Version 0.14 ------------- - -**Released on Feb 11, 2020.** - -In this release, Authlib has introduced a new way to write framework integrations -for clients. - -**Bug fixes** and enhancements in this release: - -- Fix HTTPX integrations due to HTTPX breaking changes -- Fix ES algorithms for JWS -- Allow user given ``nonce`` via :gh:`issue#180`. -- Fix OAuth errors ``get_headers`` leak. -- Fix ``code_verifier`` via :gh:`issue#165`. - -**Breaking Change**: drop sync OAuth clients of HTTPX. - - -Version 0.13 ------------- - -**Released on Nov 11, 2019. Go Async** - -This is the release that makes Authlib one more step close to v1.0. We -did a huge refactor on our integrations. Authlib believes in monolithic -design, it enables us to design the API to integrate with every framework -in the best way. In this release, Authlib has re-organized the folder -structure, moving every integration into the ``integrations`` folder. It -makes Authlib to add more integrations easily in the future. - -**RFC implementations** and updates in this release: - -- RFC7591: OAuth 2.0 Dynamic Client Registration Protocol -- RFC8628: OAuth 2.0 Device Authorization Grant - -**New integrations** and changes in this release: - -- **HTTPX** OAuth 1.0 and OAuth 2.0 clients in both sync and async way -- **Starlette** OAuth 1.0 and OAuth 2.0 client registry -- The experimental ``authlib.client.aiohttp`` has been removed - -**Bug fixes** and enhancements in this release: - -- Add custom client authentication methods for framework integrations. -- Refresh token automatically for client_credentials grant type. -- Enhancements on JOSE, specifying ``alg`` values easily for JWS and JWE. -- Add PKCE into requests OAuth2Session and HTTPX OAuth2Client. - -**Deprecate Changes**: find how to solve the deprecate issues via https://git.io/Jeclj - -Version 0.12 ------------- - -**Released on Sep 3, 2019.** - -**Breaking Change**: Authlib Grant system has been redesigned. If you -are creating OpenID Connect providers, please read the new documentation -for OpenID Connect. - -**Important Update**: Django OAuth 2.0 server integration is ready now. -You can create OAuth 2.0 provider and OpenID Connect 1.0 with Django -framework. - -RFC implementations and updates in this release: - -- RFC6749: Fixed scope validation, omit the invalid scope -- RFC7521: Added a common ``AssertionClient`` for the assertion framework -- RFC7662: Added ``IntrospectionToken`` for introspection token endpoint -- OpenID Connect Discover: Added discovery model based on RFC8414 - -Refactor and bug fixes in this release: - -- **Breaking Change**: add ``RefreshTokenGrant.revoke_old_credential`` method -- Rewrite lots of code for ``authlib.client``, no breaking changes -- Refactor ``OAuth2Request``, use explicit query and form -- Change ``requests`` to optional dependency -- Add ``AsyncAssertionClient`` for aiohttp - -**Deprecate Changes**: find how to solve the deprecate issues via https://git.io/fjPsV - -Version 0.11 ------------- - -**Released on Apr 6, 2019.** - -**BIG NEWS**: Authlib has changed its open source license **from AGPL to BSD**. - -**Important Changes**: Authlib specs module has been split into jose, oauth1, -oauth2, and oidc. Find how to solve the deprecate issues via https://git.io/fjvpt - -RFC implementations and updates in this release: - -- RFC7518: Added A128GCMKW, A192GCMKW, A256GCMKW algorithms for JWE. -- RFC5849: Removed draft-eaton-oauth-bodyhash-00 spec for OAuth 1.0. - -Small changes and bug fixes in this release: - -- Fixed missing scope on password and client_credentials grant types - of ``OAuth2Session`` via :gh:`issue#96`. -- Fixed Flask OAuth client cache detection via :gh:`issue#98`. -- Enabled ssl certificates for ``OAuth2Session`` via :gh:`PR#100`, thanks - to pingz. -- Fixed error response for invalid/expired refresh token via :gh:`issue#112`. -- Fixed error handle for invalid redirect uri via :gh:`issue#113`. -- Fixed error response redirect to fragment via :gh:`issue#114`. -- Fixed non-compliant responses from RFC7009 via :gh:`issue#119`. - -**Experiment Features**: There is an experiment ``aiohttp`` client for OAuth1 -and OAuth2 in ``authlib.client.aiohttp``. - - -Old Versions ------------- - -Find old changelog at https://github.com/lepture/authlib/releases - -- Version 0.10.0: Released on Oct 12, 2018 -- Version 0.9.0: Released on Aug 12, 2018 -- Version 0.8.0: Released on Jun 17, 2018 -- Version 0.7.0: Released on Apr 28, 2018 -- Version 0.6.0: Released on Mar 20, 2018 -- Version 0.5.1: Released on Feb 11, 2018 -- Version 0.5.0: Released on Feb 11, 2018 -- Version 0.4.1: Released on Feb 2, 2018 -- Version 0.4.0: Released on Jan 31, 2018 -- Version 0.3.0: Released on Dec 24, 2017 -- Version 0.2.1: Released on Dec 6, 2017 -- Version 0.2.0: Released on Nov 25, 2017 -- Version 0.1.0: Released on Nov 18, 2017 diff --git a/docs/client/api.rst b/docs/client/api.rst deleted file mode 100644 index 98765b0a9..000000000 --- a/docs/client/api.rst +++ /dev/null @@ -1,130 +0,0 @@ -Client API References -===================== - -.. meta:: - :description: API references on Authlib Client and its related Flask/Django integrations. - -This part of the documentation covers the interface of Authlib Client. - -Requests OAuth Sessions ------------------------ - -.. module:: authlib.integrations.requests_client - -.. autoclass:: OAuth1Session - :members: - create_authorization_url, - fetch_request_token, - fetch_access_token, - parse_authorization_response - -.. autoclass:: OAuth1Auth - :members: - -.. autoclass:: OAuth2Session - :members: - register_client_auth_method, - create_authorization_url, - fetch_token, - refresh_token, - revoke_token, - register_compliance_hook - -.. autoclass:: OAuth2Auth - -.. autoclass:: AssertionSession - - -HTTPX OAuth Clients -------------------- - -.. module:: authlib.integrations.httpx_client - -.. autoclass:: OAuth1Auth - :members: - - -.. autoclass:: AsyncOAuth1Client - :members: - create_authorization_url, - fetch_request_token, - fetch_access_token, - parse_authorization_response - -.. autoclass:: OAuth2Auth - -.. autoclass:: AsyncOAuth2Client - :members: - register_client_auth_method, - create_authorization_url, - fetch_token, - refresh_token, - revoke_token, - register_compliance_hook - -.. autoclass:: AsyncAssertionClient - - -Flask Registry and RemoteApp ----------------------------- - -.. module:: authlib.integrations.flask_client - -.. autoclass:: OAuth - :members: - init_app, - register, - create_client - -.. autoclass:: FlaskRemoteApp - :members: - authorize_redirect, - authorize_access_token, - save_authorize_data, - get, - post, - patch, - put, - delete - -Django Registry and RemoteApp ------------------------------ - -.. module:: authlib.integrations.django_client - -.. autoclass:: OAuth - :members: - register, - create_client - -.. autoclass:: DjangoRemoteApp - :members: - authorize_redirect, - authorize_access_token, - save_authorize_data, - get, - post, - patch, - put, - delete - -Starlette Registry and RemoteApp --------------------------------- - -.. module:: authlib.integrations.starlette_client - -.. autoclass:: OAuth - :members: - register, - create_client - -.. autoclass:: StarletteRemoteApp - :members: - authorize_redirect, - authorize_access_token, - save_authorize_data, - get, - post, - patch, - put, - delete diff --git a/docs/client/django.rst b/docs/client/django.rst deleted file mode 100644 index 115e3d46f..000000000 --- a/docs/client/django.rst +++ /dev/null @@ -1,150 +0,0 @@ -.. _django_client: - -Django OAuth Client -=================== - -.. meta:: - :description: The built-in Django integrations for OAuth 1.0 and - OAuth 2.0 clients, powered by Authlib. - -.. module:: authlib.integrations.django_client - :noindex: - -Looking for OAuth providers? - -- :ref:`django_oauth1_server` -- :ref:`django_oauth2_server` - -The Django client can handle OAuth 1 and OAuth 2 services. Authlib has -a shared API design among framework integrations. Get started with -:ref:`frameworks_clients`. - -Create a registry with :class:`OAuth` object:: - - from authlib.integrations.django_client import OAuth - - oauth = OAuth() - -The common use case for OAuth is authentication, e.g. let your users log in -with Twitter, GitHub, Google etc. - -.. note:: - - Please read :ref:`frameworks_clients` at first. Authlib has a shared API - design among framework integrations, learn them from :ref:`frameworks_clients`. - -.. versionchanged:: v0.13 - - Authlib moved all integrations into ``authlib.integrations`` module since v0.13. - For earlier version, developers can import the Django client with:: - - from authlib.django.client import OAuth - - -Configuration -------------- - -Authlib Django OAuth registry can load the configuration from your Django -application settings automatically. Every key value pair can be omit. -They can be configured from your Django settings:: - - AUTHLIB_OAUTH_CLIENTS = { - 'twitter': { - 'client_id': 'Twitter Consumer Key', - 'client_secret': 'Twitter Consumer Secret', - 'request_token_url': 'https://api.twitter.com/oauth/request_token', - 'request_token_params': None, - 'access_token_url': 'https://api.twitter.com/oauth/access_token', - 'access_token_params': None, - 'refresh_token_url': None, - 'authorize_url': 'https://api.twitter.com/oauth/authenticate', - 'api_base_url': 'https://api.twitter.com/1.1/', - 'client_kwargs': None - } - } - -We suggest that you keep ONLY ``client_id`` and ``client_secret`` in -your application settings, other parameters are better in ``.register()``. - -Saving Temporary Credential ---------------------------- - -In OAuth 1.0, we need to use a temporary credential to exchange access token, -this temporary credential was created before redirecting to the provider (Twitter), -we need to save this temporary credential somewhere in order to use it later. - -In OAuth 1, Django client will save the request token in sessions. In this -case, you just need to configure Session Middleware in Django:: - - MIDDLEWARE = [ - 'django.contrib.sessions.middleware.SessionMiddleware' - ] - -Follow the official Django documentation to set a proper session. Either a -database backend or a cache backend would work well. - -.. warning:: - - Be aware, using secure cookie as session backend will expose your request - token. - -Routes for Authorization ------------------------- - -Just like the example in :ref:`frameworks_clients`, everything is the same. -But there is a hint to create ``redirect_uri`` with ``request`` in Django:: - - def login(request): - # build a full authorize callback uri - redirect_uri = request.build_absolute_uri('/authorize') - return oauth.twitter.authorize_redirect(request, redirect_uri) - - -Auto Update Token via Signal ----------------------------- - -Instead of define a ``update_token`` method and passing it into OAuth registry, -it is also possible to use signal to listen for token updating:: - - from django.dispatch import receiver - from authlib.integrations.django_client import token_update - - @receiver(token_update) - def on_token_update(sender, token, refresh_token=None, access_token=None): - if refresh_token: - item = OAuth2Token.find(name=name, refresh_token=refresh_token) - elif access_token: - item = OAuth2Token.find(name=name, access_token=access_token) - else: - return - - # update old token - item.access_token = token['access_token'] - item.refresh_token = token.get('refresh_token') - item.expires_at = token['expires_at'] - item.save() - - -Django OpenID Connect Client ----------------------------- - -An OpenID Connect client is no different than a normal OAuth 2.0 client. When -register with ``openid`` scope, the built-in Django OAuth client will handle -everything automatically:: - - oauth.register( - 'google', - ... - server_metadata_url='https://accounts.google.com/.well-known/openid-configuration', - client_kwargs={'scope': 'openid profile email'} - ) - -When we get the returned token:: - - token = oauth.google.authorize_access_token(request) - -We can get the user information from the ``id_token`` in the returned token:: - - userinfo = oauth.google.parse_id_token(request, token) - -Find Django Google login example at https://github.com/authlib/demo-oauth-client/tree/master/django-google-login diff --git a/docs/client/index.rst b/docs/client/index.rst deleted file mode 100644 index 60d90436f..000000000 --- a/docs/client/index.rst +++ /dev/null @@ -1,69 +0,0 @@ -OAuth Clients -============= - -.. meta:: - :description: This documentation contains Python OAuth 1.0 and OAuth 2.0 Clients - implementation with requests, HTTPX, Flask, Django and Starlette. - -This part of the documentation contains information on the client parts. Authlib -provides many frameworks integrations, including: - -* The famous Python Requests_ -* A next generation HTTP client for Python: httpx_ -* Flask_ web framework integration -* Django_ web framework integration -* Starlette_ web framework integration -* FastAPI_ web framework integration - -In order to use Authlib client, you have to install each library yourself. For -example, you want to use ``requests`` OAuth clients:: - - $ pip install Authlib requests - -For instance, you want to use ``httpx`` OAuth clients:: - - $ pip install -U Authlib httpx - -Here is a simple overview of Flask OAuth client:: - - from flask import Flask, jsonify - from authlib.integrations.flask_client import OAuth - - app = Flask(__name__) - oauth = OAuth(app) - github = oauth.register('github', {...}) - - @app.route('/login') - def login(): - redirect_uri = url_for('authorize', _external=True) - return github.authorize_redirect(redirect_uri) - - @app.route('/authorize') - def authorize(): - token = github.authorize_access_token() - # you can save the token into database - profile = github.get('/user', token=token) - return jsonify(profile) - -Follow the documentation below to find out more in detail. - -.. toctree:: - :maxdepth: 2 - - oauth1 - oauth2 - requests - httpx - frameworks - flask - django - starlette - fastapi - api - -.. _Requests: https://requests.readthedocs.io/en/master/ -.. _httpx: https://www.encode.io/httpx/ -.. _Flask: https://flask.palletsprojects.com -.. _Django: https://djangoproject.com -.. _Starlette: https://starlette.io -.. _FastAPI: https://fastapi.tiangolo.com/ diff --git a/docs/client/starlette.rst b/docs/client/starlette.rst deleted file mode 100644 index 32e5a58e2..000000000 --- a/docs/client/starlette.rst +++ /dev/null @@ -1,141 +0,0 @@ -.. _starlette_client: - -Starlette OAuth Client -====================== - -.. meta:: - :description: The built-in Starlette integrations for OAuth 1.0, OAuth 2.0 - and OpenID Connect clients, powered by Authlib. - -.. module:: authlib.integrations.starlette_client - :noindex: - -Starlette_ is a lightweight ASGI framework/toolkit, which is ideal for -building high performance asyncio services. - -.. _Starlette: https://www.starlette.io/ - -This documentation covers OAuth 1.0, OAuth 2.0 and OpenID Connect Client -support for Starlette. Because all the frameworks integrations share the -same API, it is best to: - -Read :ref:`frameworks_clients` at first. - -The difference between Starlette and Flask/Django integrations is Starlette -is **async**. We will use ``await`` for the functions we need to call. But -first, let's create an :class:`OAuth` instance:: - - from authlib.integrations.starlette_client import OAuth - - oauth = OAuth() - -The common use case for OAuth is authentication, e.g. let your users log in -with Twitter, GitHub, Google etc. - -Configuration -------------- - -Starlette can load configuration from environment; Authlib implementation -for Starlette client can use this configuration. Here is an example of how -to do it:: - - from starlette.config import Config - - config = Config('.env') - oauth = OAuth(config) - -Authlib will load ``client_id`` and ``client_secret`` from the configuration, -take google as an example:: - - oauth.register(name='google', ...) - -It will load **GOOGLE_CLIENT_ID** and **GOOGLE_CLIENT_SECRET** from the -environment. - -Register Remote Apps --------------------- - -``oauth.register`` is the same as :ref:`frameworks_clients`:: - - oauth.register( - 'google', - client_id='...', - client_secret='...', - ... - ) - -However, unlike Flask/Django, Starlette OAuth registry is using HTTPX -:class:`~authlib.integrations.httpx_client.AsyncOAuth1Client` and -:class:`~authlib.integrations.httpx_client.AsyncOAuth2Client` as the OAuth -backends. While Flask and Django are using the Requests version of -:class:`~authlib.integrations.requests_client.OAuth1Session` and -:class:`~authlib.integrations.requests_client.OAuth2Session`. - - -Enable Session for OAuth 1.0 ----------------------------- - -With OAuth 1.0, we need to use a temporary credential to exchange for an access token. -This temporary credential is created before redirecting to the provider (Twitter), -and needs to be saved somewhere in order to use it later. - -With OAuth 1, the Starlette client will save the request token in sessions. To -enable this, we need to add the ``SessionMiddleware`` middleware to the -application, which requires the installation of the ``itsdangerous`` package:: - - from starlette.applications import Starlette - from starlette.middleware.sessions import SessionMiddleware - - app = Starlette() - app.add_middleware(SessionMiddleware, secret_key="some-random-string") - -However, using the ``SessionMiddleware`` will store the temporary credential as -a secure cookie which will expose your request token to the client. - -Routes for Authorization ------------------------- - -Just like the examples in :ref:`frameworks_clients`, but Starlette is **async**, -the routes for authorization should look like:: - - @app.route('/login') - async def login(request): - google = oauth.create_client('google') - redirect_uri = request.url_for('authorize') - return await google.authorize_redirect(request, redirect_uri) - - @app.route('/auth') - async def authorize(request): - google = oauth.create_client('google') - token = await google.authorize_access_token(request) - user = await google.parse_id_token(request, token) - # do something with the token and profile - return '...' - -Starlette OpenID Connect ------------------------- - -An OpenID Connect client is no different than a normal OAuth 2.0 client, just add -``openid`` scope when ``.register``. In the above example, in ``authorize``:: - - user = await google.parse_id_token(request, token) - -There is a ``id_token`` in the response ``token``. We can parse userinfo from this -``id_token``. - -Here is how you can add ``openid`` scope in ``.register``:: - - oauth.register( - 'google', - ... - server_metadata_url='https://accounts.google.com/.well-known/openid-configuration', - client_kwargs={'scope': 'openid profile email'} - ) - -Examples --------- - -We have Starlette demos at https://github.com/authlib/demo-oauth-client - -1. OAuth 1.0: `Starlette Twitter login `_ -2. OAuth 2.0: `Starlette Google login `_ diff --git a/docs/community/authors.rst b/docs/community/authors.rst index 61bc011b7..aea944e18 100644 --- a/docs/community/authors.rst +++ b/docs/community/authors.rst @@ -7,27 +7,43 @@ Authlib is written and maintained by `Hsiaoming Yang `_. Contributors ------------ -Here is the full list of the main contributors: +Here is the list of the main contributors: -https://github.com/lepture/authlib/graphs/contributors +- Ber Zoidberg +- Tom Christie +- Grey Li +- Pablo Marti +- Mario Jimenez Carrasco +- Bastian Venthur +- Nuno Santos +- Éloi Rivard +And more on https://github.com/authlib/authlib/graphs/contributors Sponsors -------- -Become a sponsor via `GitHub Sponsors`_ or Patreon_: +Become a sponsor via `GitHub Sponsors`_ or Patreon_ to support Authlib. + +Here is a full list of our sponsors, including past sponsors: * `Auth0 `_ +* `Authing `_ +Find out the :ref:`benefits for sponsorship `. Backers ------- -Become a backer `GitHub Sponsors`_ or via Patreon_: +Become a backer `GitHub Sponsors`_ or via Patreon_ to support Authlib. + +Here is a full list of our backers: * `Evilham `_ * `Aveline `_ * `Callam `_ +* `Krishna Kumar `_ +* `Yaal Coop `_ .. _`GitHub Sponsors`: https://github.com/sponsors/lepture .. _Patreon: https://www.patreon.com/lepture diff --git a/docs/community/awesome.rst b/docs/community/awesome.rst index 6e81cd698..499f704b3 100644 --- a/docs/community/awesome.rst +++ b/docs/community/awesome.rst @@ -73,3 +73,5 @@ Articles - `Using Authlib with gspread `_. - `Multipart Upload to Google Cloud Storage `_. - `Create Twitter login for FastAPI `_. +- `Google login for FastAPI `_. +- `FastAPI with Google OAuth `_. diff --git a/docs/community/contribute.rst b/docs/community/contribute.rst index d0e16c116..6635cae64 100644 --- a/docs/community/contribute.rst +++ b/docs/community/contribute.rst @@ -41,7 +41,7 @@ Thank you. Now that you have a fix for Authlib, please describe it clearly in your pull request. There are some requirements for a pull request to be accepted: -* Follow PEP8 code style. You can use flake8 to check your code style. +* You can use ruff to check your code style. * Tests for the code changes are required. * Please add documentation for it, if it requires. @@ -61,7 +61,8 @@ Finance support is also welcome. A better finance can make Authlib listed in the Authlib GitHub repository, or have your company logo placed on this website. - `Become a backer or sponsor via Patreon `_ + * `Become a backer or sponsor via Patreon `_ + * `Become a backer or sponsor via GitHub `_ 2. **One Time Donation** diff --git a/docs/community/funding.rst b/docs/community/funding.rst new file mode 100644 index 000000000..8fff3141b --- /dev/null +++ b/docs/community/funding.rst @@ -0,0 +1,89 @@ +.. _funding: + +Funding +======= + +If you use Authlib and its related projects commercially we strongly +encourage you to invest in its **sustainable** development by sponsorship. + +We accept funding with paid license and sponsorship. With the funding, it +will: + +* contribute to faster releases, more features, and higher quality software. +* allow more time to be invested in the documentation, issues, and community support. + +And you can also get benefits from us: + +1. access to some of our private repositories +2. access to our `private PyPI `_. +3. join our security mail list. + +Get more details on our sponsor tiers page at: + +1. GitHub sponsors: https://github.com/sponsors/lepture +2. Patreon: https://www.patreon.com/lepture + +Insiders +-------- + +Insiders are people who have access to our private repositories, you can become +an insider with: + +1. Purchasing a paid license at https://authlib.org/plans +2. Become a sponsor with tiers including "Access to our private repos" benefit + +PyPI +---- + +We offer a private PyPI server to release early security fixes and features. +You can find more details about this PyPI server at: + +https://authlib.org/pypi + +Goals +----- + +The following list of funding goals shows features and additional addons +we are going to add. + +Funding Goal: $500/month +~~~~~~~~~~~~~~~~~~~~~~~~ + +* :bdg-success:`done` setup a private PyPI +* :bdg-warning:`todo` A running demo of loginpass services +* :bdg-warning:`todo` Starlette integration of loginpass + + +Funding Goal: $2000/month +~~~~~~~~~~~~~~~~~~~~~~~~~ + +* :bdg-warning:`todo` A simple running demo of OIDC provider in Flask + +When the demo is complete, source code of the demo will only be available to our insiders. + +Funding Goal: $5000/month +~~~~~~~~~~~~~~~~~~~~~~~~~ + +In Authlib v2.0, we will start working on async provider integrations. + +* :bdg-warning:`todo` Starlette (FastAPI) OAuth 1.0 provider integration +* :bdg-warning:`todo` Starlette (FastAPI) OAuth 2.0 provider integration +* :bdg-warning:`todo` Starlette (FastAPI) OIDC provider integration + +Funding Goal: $9000/month +~~~~~~~~~~~~~~~~~~~~~~~~~ + +In Authlib v3.0, we will add built-in support for SAML. + +* :bdg-warning:`todo` SAML 2.0 implementation +* :bdg-warning:`todo` RFC7522 (SAML) 2.0 Profile for OAuth 2.0 Client Authentication and Authorization Grants +* :bdg-warning:`todo` CBOR Object Signing and Encryption +* :bdg-warning:`todo` A complex running demo of OIDC provider + +Our Sponsors +------------ + +Here is our current sponsors, we keep a full list of our sponsors in the Authors page. + +.. raw:: html + :file: ../_templates/funding.html diff --git a/docs/community/index.rst b/docs/community/index.rst index fe1d91302..7952015ef 100644 --- a/docs/community/index.rst +++ b/docs/community/index.rst @@ -8,6 +8,7 @@ issues and finance. .. toctree:: :maxdepth: 2 + funding support security contribute diff --git a/docs/community/licenses.rst b/docs/community/licenses.rst index feb341a8a..8a84bd5b7 100644 --- a/docs/community/licenses.rst +++ b/docs/community/licenses.rst @@ -1,8 +1,15 @@ Authlib Licenses ================ -Authlib offers two licenses, one is BSD for open source projects, one is -a commercial license for closed source projects. +Authlib offers two licenses: + +1. BSD LICENSE +2. COMMERCIAL-LICENSE + +Any project, open or closed source, can use the BSD license. +If your company needs commercial support, you can purchase a commercial license at +`Authlib Plans `_. You can find more information at +https://authlib.org/support. Open Source License ------------------- diff --git a/docs/community/security.rst b/docs/community/security.rst index 3c1dda775..cd84764aa 100644 --- a/docs/community/security.rst +++ b/docs/community/security.rst @@ -28,4 +28,5 @@ Here is the process when we have received a security report: Previous CVEs ------------- -.. note:: No CVEs yet +- CVE-2022-39174 +- CVE-2022-39175 diff --git a/docs/community/support.rst b/docs/community/support.rst index 89e9dd8fd..e6515a1ce 100644 --- a/docs/community/support.rst +++ b/docs/community/support.rst @@ -30,7 +30,7 @@ Feature Requests If you have feature requests, please comment on `Features Checklist`_. If they are accepted, they will be listed in the post. -.. _`Features Checklist`: https://github.com/lepture/authlib/issues/1 +.. _`Features Checklist`: https://github.com/authlib/authlib/issues/1 Commercial Support diff --git a/docs/community/sustainable.rst b/docs/community/sustainable.rst index 077d94951..47ac10f6f 100644 --- a/docs/community/sustainable.rst +++ b/docs/community/sustainable.rst @@ -6,9 +6,6 @@ Sustainable A sustainable project is trustworthy to use in your production environment. To make this project sustainable, we need your help. Here are several options: -.. raw:: html - :file: ../_templates/sustainable.html - Community Contribute -------------------- @@ -29,11 +26,15 @@ You are welcome to become a backer or a sponsor. .. _`GitHub Sponsors`: https://github.com/sponsors/lepture .. _Patreon: https://www.patreon.com/lepture +Find out the :ref:`benefits for sponsorship `. + Commercial License ------------------ -Authlib is licensed under BSD for open source projects. If you are -running a business, consider to purchase a commercial license instead. +Authlib is licensed under BSD-3 for any project. +If you are running a business, and you need advanced support, +and wish to help Authlib sustainability, +please consider to purchase a commercial license instead. Find more information on https://authlib.org/support#commercial-license diff --git a/docs/conf.py b/docs/conf.py index 70cd76f29..7e421019d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,138 +1,81 @@ -import os -import sys -sys.path.insert(0, os.path.abspath('..')) +import warnings import authlib -import sphinx_typlog_theme +from authlib.deprecate import AuthlibDeprecationWarning -extensions = ['sphinx.ext.autodoc'] -templates_path = ['_templates'] +# we will keep authlib.jose module until 2.0.0 +warnings.simplefilter("ignore", AuthlibDeprecationWarning) -source_suffix = '.rst' -master_doc = 'index' - -project = u'Authlib' -copyright = u'2017, Hsiaoming Ltd' -author = u'Hsiaoming Yang' - -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The short X.Y version. +project = "Authlib" +copyright = "© 2017, Hsiaoming Ltd" +author = "Hsiaoming Yang" version = authlib.__version__ -# The full version, including alpha/beta/rc tags. release = version -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# -# This is also used if you do content translation via gettext catalogs. -# Usually you set "language" from the command line for these cases. -language = 'en' - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' - -html_theme = 'sphinx_typlog_theme' -html_favicon = '_static/favicon.ico' -html_theme_path = [sphinx_typlog_theme.get_path()] -html_theme_options = { - 'logo': 'authlib.svg', - 'color': '#3E7FCB', - 'description': ( - 'The ultimate Python library in building OAuth and OpenID Connect ' - 'servers. JWS, JWE, JWK, JWA, JWT are included.' - ), - 'github_user': 'lepture', - 'github_repo': 'authlib', - 'twitter': 'authlib', - 'og_image': 'https://authlib.org/logo.png', - 'meta_html': ( - '' - ) -} - -html_context = {} - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] - -_sidebar_templates = [ - 'logo.html', - 'github.html', - 'sponsors.html', - 'globaltoc.html', - 'links.html', - 'searchbox.html', - 'tidelift.html', -] -if '.dev' in release: - version_warning = ( - 'This is the documentation of the development version, check the ' - 'Stable Version documentation.' - ) - html_theme_options['warning'] = version_warning - -html_sidebars = { - '**': _sidebar_templates -} - -# -- Options for HTMLHelp output ------------------------------------------ - -# Output file base name for HTML help builder. -htmlhelp_basename = 'Authlibdoc' - - -# -- Options for LaTeX output --------------------------------------------- - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, documentclass [howto, manual, or own class]). -latex_documents = [ - (master_doc, 'Authlib.tex', u'Authlib Documentation', - u'Hsiaoming Yang', 'manual'), -] - - -# -- Options for manual page output --------------------------------------- - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'authlib', u'Authlib Documentation', [author], 1) +templates_path = ["_templates"] +html_static_path = ["_static"] +html_css_files = [ + "custom.css", ] +html_theme = "shibuya" +html_copy_source = False +html_show_sourcelink = False -# -- Options for Texinfo output ------------------------------------------- +language = "en" -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - ( - master_doc, 'Authlib', u'Authlib Documentation', - author, 'Authlib', 'One line description of project.', - 'Miscellaneous' - ), +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.extlinks", + "sphinx.ext.intersphinx", + "sphinx_copybutton", + "sphinx_design", ] -html_css_files = [ - 'sponsors.css', -] -html_js_files = [ - 'sponsors.js', -] +extlinks = { + "issue": ("https://github.com/authlib/authlib/issues/%s", "issue #%s"), + "PR": ("https://github.com/authlib/authlib/pull/%s", "pull request #%s"), +} +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), + "joserfc": ("https://jose.authlib.org/en/", None), +} +html_favicon = "_static/icon.svg" +html_theme_options = { + "accent_color": "blue", + "globaltoc_expand_depth": 1, + "og_image_url": "https://authlib.org/logo.png", + "light_logo": "_static/light-logo.svg", + "dark_logo": "_static/dark-logo.svg", + "twitter_site": "authlib", + "twitter_creator": "lepture", + "twitter_url": "https://twitter.com/authlib", + "github_url": "https://github.com/authlib/authlib", + "discord_url": "https://discord.gg/HvBVAeNAaV", + "nav_links": [ + { + "title": "Projects", + "children": [ + { + "title": "Authlib", + "url": "https://authlib.org/", + "summary": "OAuth, JOSE, OpenID, etc.", + }, + { + "title": "JOSE RFC", + "url": "https://jose.authlib.org/", + "summary": "JWS, JWE, JWK, and JWT.", + }, + { + "title": "OTP Auth", + "url": "https://otp.authlib.org/", + "summary": "One time password, HOTP/TOTP.", + }, + ], + }, + {"title": "Sponsor me", "url": "https://github.com/sponsors/lepture"}, + ], +} -def setup(app): - sphinx_typlog_theme.add_badge_roles(app) - sphinx_typlog_theme.add_github_roles(app, 'lepture/authlib') +html_context = {} diff --git a/docs/django/index.rst b/docs/django/index.rst deleted file mode 100644 index c80ac2c4b..000000000 --- a/docs/django/index.rst +++ /dev/null @@ -1,12 +0,0 @@ -Django OAuth Providers -====================== - -Authlib has built-in Django integrations for building OAuth 1.0 and -OAuth 2.0 servers. It is best if developers can read -:ref:`intro_oauth1` and :ref:`intro_oauth2` at first. - -.. toctree:: - :maxdepth: 2 - - 1/index - 2/index diff --git a/docs/flask/1/api.rst b/docs/flask/1/api.rst deleted file mode 100644 index 71693c618..000000000 --- a/docs/flask/1/api.rst +++ /dev/null @@ -1,39 +0,0 @@ -API References of Flask OAuth 1.0 Server -======================================== - -This part of the documentation covers the interface of Flask OAuth 1.0 -Server. - -.. module:: authlib.integrations.flask_oauth1 - -.. autoclass:: AuthorizationServer - :members: - -.. autoclass:: ResourceProtector - :member-order: bysource - :members: - -.. data:: current_credential - - Routes protected by :class:`ResourceProtector` can access current credential - with this variable. - - -SQLAlchemy Help Functions -------------------------- - -.. warning:: We will drop ``sqla_oauth2`` module in version 1.0. - -.. module:: authlib.integrations.sqla_oauth1 - -.. autofunction:: create_query_client_func - -.. autofunction:: create_query_token_func - -.. autofunction:: create_exists_nonce_func - -.. autofunction:: register_nonce_hooks - -.. autofunction:: register_temporary_credential_hooks - -.. autofunction:: register_token_credential_hooks diff --git a/docs/flask/1/resource-server.rst b/docs/flask/1/resource-server.rst deleted file mode 100644 index 139dfad25..000000000 --- a/docs/flask/1/resource-server.rst +++ /dev/null @@ -1,64 +0,0 @@ -Resource Servers -================ - -Protect users resources, so that only the authorized clients with the -authorized access token can access the given scope resources. - -A resource server can be a different server other than the authorization -server. Here is the way to protect your users' resources:: - - from flask import jsonify - from authlib.integrations.flask_oauth1 import ResourceProtector, current_credential - from authlib.integrations.flask_oauth1 import create_exists_nonce_func - from authlib.integrations.sqla_oauth1 import ( - create_query_client_func, - create_query_token_func - ) - - query_client = create_query_client_func(db.session, Client) - query_token = create_query_token_func(db.session, TokenCredential) - exists_nonce = create_exists_nonce_func(cache) - # OR: authlib.integrations.sqla_oauth1.create_exists_nonce_func - - require_oauth = ResourceProtector( - app, query_client=query_client, - query_token=query_token, - exists_nonce=exists_nonce, - ) - # or initialize it lazily - require_oauth = ResourceProtector() - require_oauth.init_app( - app, - query_client=query_client, - query_token=query_token, - exists_nonce=exists_nonce, - ) - - @app.route('/user') - @require_oauth() - def user_profile(): - user = current_credential.user - return jsonify(user) - -The ``current_credential`` is a proxy to the Token model you have defined above. -Since there is a ``user`` relationship on the Token model, we can access this -``user`` with ``current_credential.user``. - - -MethodView & Flask-Restful --------------------------- - -You can also use the ``require_oauth`` decorator in ``flask.views.MethodView`` -and ``flask_restful.Resource``:: - - from flask.views import MethodView - - class UserAPI(MethodView): - decorators = [require_oauth()] - - - from flask_restful import Resource - - class UserAPI(Resource): - method_decorators = [require_oauth()] - diff --git a/docs/flask/index.rst b/docs/flask/index.rst deleted file mode 100644 index f778df63c..000000000 --- a/docs/flask/index.rst +++ /dev/null @@ -1,12 +0,0 @@ -Flask OAuth Providers -===================== - -Authlib has built-in Flask integrations for building OAuth 1.0, OAuth 2.0 and -OpenID Connect servers. It is best if developers can read :ref:`intro_oauth1` -and :ref:`intro_oauth2` at first. - -.. toctree:: - :maxdepth: 2 - - 1/index - 2/index diff --git a/docs/index.rst b/docs/index.rst index 96c82f2ee..da81e9a9c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,43 +9,18 @@ Authlib: Python Authentication Release v\ |version|. (:ref:`Installation `) -The ultimate Python library in building OAuth and OpenID Connect servers. +The ultimate Python library in building OAuth and OpenID Connect servers and clients. It is designed from low level specifications implementations to high level frameworks integrations, to meet the needs of everyone. -Authlib is compatible with Python2.7+ and Python3.6+. -(We will drop Python 2 support when Authlib 1.0 is released) - -User's Guide ------------- - -This part of the documentation begins with some background information -about Authlib, and installation of Authlib. Then it will explain OAuth 1.0, -OAuth 2.0, and JOSE. At last, it shows the implementation in frameworks, and -libraries such as Flask, Django, Requests, HTTPX, Starlette, FastAPI, and etc. +Authlib is compatible with Python3.10+. .. toctree:: :maxdepth: 2 basic/index - client/index + oauth2/index + oauth1/index jose/index - oauth/index - flask/index - django/index - specs/index community/index - - -Get Updates ------------ - -Stay tuned with Authlib, here is a history of Authlib changes. - -.. toctree:: - :maxdepth: 2 - - changelog - -Consider to follow `Authlib on Twitter `_, -and subscribe `Authlib Blog `_. + upgrades/index diff --git a/docs/jose/index.rst b/docs/jose/index.rst index 4335ba931..b59aa6733 100644 --- a/docs/jose/index.rst +++ b/docs/jose/index.rst @@ -1,7 +1,7 @@ .. _jose: -JOSE Guide -========== +JOSE +==== This part of the documentation contains information on the JOSE implementation. It includes: @@ -12,6 +12,15 @@ It includes: 4. JSON Web Algorithm (JWA) 5. JSON Web Token (JWT) +.. versionchanged:: 1.7 + We are deprecating ``authlib.jose`` module in favor of joserfc_. + It will be removed in Authlib 1.8. + +.. _joserfc: https://jose.authlib.org/en/ + +Usage +----- + A simple example on how to use JWT with Authlib:: from authlib.jose import jwt @@ -23,6 +32,9 @@ A simple example on how to use JWT with Authlib:: header = {'alg': 'RS256'} s = jwt.encode(header, payload, key) +Guide +----- + Follow the documentation below to find out more in detail. .. toctree:: @@ -32,3 +44,4 @@ Follow the documentation below to find out more in detail. jwe jwk jwt + specs/index diff --git a/docs/jose/jwe.rst b/docs/jose/jwe.rst index 9a771a9c5..58ca4f72a 100644 --- a/docs/jose/jwe.rst +++ b/docs/jose/jwe.rst @@ -9,6 +9,13 @@ JSON Web Encryption (JWE) JSON Web Encryption (JWE) represents encrypted content using JSON-based data structures. +.. important:: + + We are splitting the ``jose`` module into a separated package. You may be + interested in joserfc_. + +.. _joserfc: https://jose.authlib.org/en/dev/guide/jwe/ + There are two types of JWE Serializations: 1. JWE Compact Serialization diff --git a/docs/jose/jwk.rst b/docs/jose/jwk.rst index db6b98f75..d057ca671 100644 --- a/docs/jose/jwk.rst +++ b/docs/jose/jwk.rst @@ -3,10 +3,12 @@ JSON Web Key (JWK) ================== -.. versionchanged:: v0.15 +.. important:: - This documentation is updated for v0.15. Please check "stable" documentation for - Authlib v0.14. + We are splitting the ``jose`` module into a separated package. You may be + interested in joserfc_. + +.. _joserfc: https://jose.authlib.org/en/dev/guide/jwk/ .. module:: authlib.jose :noindex: diff --git a/docs/jose/jws.rst b/docs/jose/jws.rst index ff21a225d..fdd1fdd6e 100644 --- a/docs/jose/jws.rst +++ b/docs/jose/jws.rst @@ -10,6 +10,14 @@ JSON Web Signature (JWS) represents content secured with digital signatures or Message Authentication Codes (MACs) using JSON-based data structures. +.. important:: + + We are splitting the ``jose`` module into a separated package. You may be + interested in joserfc_. + +.. _joserfc: https://jose.authlib.org/en/dev/guide/jws/ + + There are two types of JWS Serializations: 1. JWS Compact Serialization @@ -93,8 +101,9 @@ algorithms: 1. HS256, HS384, HS512 2. RS256, RS384, RS512 -3. ES256, ES384, ES512 +3. ES256, ES384, ES512, ES256K 4. PS256, PS384, PS512 +5. EdDSA For example, a JWS with RS256 requires a private PEM key to sign the JWS:: @@ -115,6 +124,28 @@ To deserialize a JWS Compact Serialization, use jws_header = data['header'] payload = data['payload'] +.. important:: + + The above method is susceptible to a signature bypass described in CVE-2016-10555. + It allows mixing symmetric algorithms and asymmetric algorithms. You should never + combine symmetric (HS) and asymmetric (RS, ES, PS) signature schemes. + + If you must support both protocols use a custom key loader which provides a different + keys for different methods. + +Load a different ``key`` for symmetric and asymmetric signatures:: + + def load_key(header, payload): + if header['alg'] == 'RS256': + return rsa_pub_key + elif header['alg'] == 'HS256': + return shared_secret + else: + raise UnsupportedAlgorithmError() + + claims = jws.deserialize_compact(token, load_key) + + A ``key`` can be dynamically loaded, if you don't know which key to be used:: def load_key(header, payload): diff --git a/docs/jose/jwt.rst b/docs/jose/jwt.rst index f3cf9f450..56d615d85 100644 --- a/docs/jose/jwt.rst +++ b/docs/jose/jwt.rst @@ -3,6 +3,13 @@ JSON Web Token (JWT) ==================== +.. important:: + + We are splitting the ``jose`` module into a separated package. You may be + interested in joserfc_. + +.. _joserfc: https://jose.authlib.org/en/dev/guide/jwt/ + .. module:: authlib.jose :noindex: @@ -14,9 +21,10 @@ keys of :ref:`specs/rfc7517`:: >>> from authlib.jose import jwt >>> header = {'alg': 'RS256'} >>> payload = {'iss': 'Authlib', 'sub': '123', ...} - >>> key = read_file('private.pem') - >>> s = jwt.encode(header, payload, key) - >>> claims = jwt.decode(s, read_file('public.pem')) + >>> private_key = read_file('private.pem') + >>> s = jwt.encode(header, payload, private_key) + >>> public_key = read_file('public.pem') + >>> claims = jwt.decode(s, public_key) >>> print(claims) {'iss': 'Authlib', 'sub': '123', ...} >>> print(claims.header) @@ -47,8 +55,8 @@ payload with the given ``alg`` in header:: >>> from authlib.jose import jwt >>> header = {'alg': 'RS256'} >>> payload = {'iss': 'Authlib', 'sub': '123', ...} - >>> key = read_file('private.pem') - >>> s = jwt.encode(header, payload, key) + >>> private_key = read_file('private.pem') + >>> s = jwt.encode(header, payload, private_key) The available keys in headers are defined by :ref:`specs/rfc7515`. @@ -59,7 +67,16 @@ JWT Decode dict of the payload:: >>> from authlib.jose import jwt - >>> claims = jwt.decode(s, read_file('public.pem')) + >>> public_key = read_file('public.pem') + >>> claims = jwt.decode(s, public_key) + +.. important:: + + This decoding method is insecure. By default ``jwt.decode`` parses the alg header. + This allows symmetric macs and asymmetric signatures. If both are allowed a signature bypass described in CVE-2016-10555 is possible. + + See the following section for a mitigation. + The returned value is a :class:`JWTClaims`, check the next section to validate claims value. @@ -74,6 +91,28 @@ of supported ``alg`` into :class:`JsonWebToken`:: >>> from authlib.jose import JsonWebToken >>> jwt = JsonWebToken(['RS256']) +.. important:: + + You should never combine symmetric (HS) and asymmetric (RS, ES, PS) signature schemes. + When both are allowed a signature bypass described in CVE-2016-10555 is possible. + + If you must support both protocols use a custom key loader which provides a different + keys for different methods. + +Load a different ``key`` for symmetric and asymmetric signatures:: + + def load_key(header, payload): + if header['alg'] == 'RS256': + return rsa_pub_key + elif header['alg'] == 'HS256': + return shared_secret + else: + raise UnsupportedAlgorithmError() + + claims = jwt.decode(token, load_key) + + + JWT Payload Claims Validation ----------------------------- @@ -110,3 +149,35 @@ It is a dict configuration, the option key is the name of a claim. - **values**: claim value can be any one in the values list. - **value**: claim value MUST be the same value. - **validate**: a function to validate the claim value. + + +Use dynamic keys +---------------- + +When ``.encode`` and ``.decode`` a token, there is a ``key`` parameter to use. +This ``key`` can be the bytes of your PEM key, a JWK set, and a function. + +There are cases that you don't know which key to use to ``.decode`` the token. +For instance, you have a JWK set:: + + jwks = { + "keys": [ + { "kid": "k1", ...}, + { "kid": "k2", ...}, + ] + } + +And in the token, it has a ``kid=k2`` in the header part, if you pass ``jwks`` to +the ``key`` parameter, Authlib will auto resolve the correct key:: + + jwt.decode(s, key=jwks, ...) + +It is also possible to resolve the correct key by yourself:: + + def resolve_key(header, payload): + return my_keys[header['kid']] + + jwt.decode(s, key=resolve_key) + +For ``.encode``, if you pass a JWK set, it will randomly pick a key and assign its +``kid`` into the header. diff --git a/docs/jose/specs/index.rst b/docs/jose/specs/index.rst new file mode 100644 index 000000000..e0809a1d5 --- /dev/null +++ b/docs/jose/specs/index.rst @@ -0,0 +1,13 @@ +Specifications +============== + +.. toctree:: + :maxdepth: 1 + + rfc7515 + rfc7516 + rfc7517 + rfc7518 + rfc7519 + rfc7638 + rfc8037 diff --git a/docs/specs/rfc7515.rst b/docs/jose/specs/rfc7515.rst similarity index 100% rename from docs/specs/rfc7515.rst rename to docs/jose/specs/rfc7515.rst diff --git a/docs/specs/rfc7516.rst b/docs/jose/specs/rfc7516.rst similarity index 100% rename from docs/specs/rfc7516.rst rename to docs/jose/specs/rfc7516.rst diff --git a/docs/specs/rfc7517.rst b/docs/jose/specs/rfc7517.rst similarity index 100% rename from docs/specs/rfc7517.rst rename to docs/jose/specs/rfc7517.rst diff --git a/docs/specs/rfc7518.rst b/docs/jose/specs/rfc7518.rst similarity index 98% rename from docs/specs/rfc7518.rst rename to docs/jose/specs/rfc7518.rst index cd2304d3e..e9ebee357 100644 --- a/docs/specs/rfc7518.rst +++ b/docs/jose/specs/rfc7518.rst @@ -52,7 +52,7 @@ This section is defined by RFC7518 `Section 3.4`_. 1. ES256: ECDSA using P-256 and SHA-256 2. ES384: ECDSA using P-384 and SHA-384 -3. ES384: ECDSA using P-521 and SHA-512 +3. ES512: ECDSA using P-521 and SHA-512 Digital Signature with RSASSA-PSS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/specs/rfc7519.rst b/docs/jose/specs/rfc7519.rst similarity index 100% rename from docs/specs/rfc7519.rst rename to docs/jose/specs/rfc7519.rst diff --git a/docs/specs/rfc7638.rst b/docs/jose/specs/rfc7638.rst similarity index 100% rename from docs/specs/rfc7638.rst rename to docs/jose/specs/rfc7638.rst diff --git a/docs/specs/rfc8037.rst b/docs/jose/specs/rfc8037.rst similarity index 100% rename from docs/specs/rfc8037.rst rename to docs/jose/specs/rfc8037.rst diff --git a/docs/oauth/1/index.rst b/docs/oauth/1/index.rst deleted file mode 100644 index 894471a12..000000000 --- a/docs/oauth/1/index.rst +++ /dev/null @@ -1,13 +0,0 @@ -OAuth 1.0 -========= - -OAuth 1.0 is the standardization and combined wisdom of many well established industry protocols -at its creation time. It was first introduced as Twitter's open protocol. It is similar to other protocols -at that time in use (Google AuthSub, AOL OpenAuth, Yahoo BBAuth, Upcoming API, Flickr API, etc). - -If you are creating an open platform, AUTHLIB ENCOURAGE YOU USE OAUTH 2.0 INSTEAD. - -.. toctree:: - :maxdepth: 2 - - intro diff --git a/docs/oauth/2/index.rst b/docs/oauth/2/index.rst deleted file mode 100644 index b6c09eceb..000000000 --- a/docs/oauth/2/index.rst +++ /dev/null @@ -1,7 +0,0 @@ -OAuth 2.0 -========= - -.. toctree:: - :maxdepth: 2 - - intro diff --git a/docs/oauth/index.rst b/docs/oauth/index.rst deleted file mode 100644 index ce8f35a9f..000000000 --- a/docs/oauth/index.rst +++ /dev/null @@ -1,12 +0,0 @@ -OAuth & OpenID Connect -====================== - -This section contains introduction and implementation of Authlib core -OAuth 1.0, OAuth 2.0, and OpenID Connect. - -.. toctree:: - :maxdepth: 2 - - 1/index - 2/index - oidc/index diff --git a/docs/oauth/oidc/core.rst b/docs/oauth/oidc/core.rst deleted file mode 100644 index d0a89931b..000000000 --- a/docs/oauth/oidc/core.rst +++ /dev/null @@ -1,23 +0,0 @@ -OpenID Connect Core -=================== - -This section is about the core part of OpenID Connect. Authlib implemented -`OpenID Connect Core 1.0`_ on top of OAuth 2.0. It enhanced OAuth 2.0 with: - -.. module:: authlib.oidc.core.grants - :noindex: - -1. :class:`OpenIDCode` extension for Authorization code flow -2. :class:`OpenIDImplicitGrant` grant type for implicit flow -3. :class:`OpenIDHybridGrant` grant type for hybrid flow - -.. _`OpenID Connect Core 1.0`: https://openid.net/specs/openid-connect-core-1_0.html - -Authorization Code Flow ------------------------ - -Implicit Flow -------------- - -Hybrid Flow ------------ diff --git a/docs/oauth/oidc/discovery.rst b/docs/oauth/oidc/discovery.rst deleted file mode 100644 index 361bc57bc..000000000 --- a/docs/oauth/oidc/discovery.rst +++ /dev/null @@ -1,77 +0,0 @@ -OpenID Connect Discovery -======================== - -This section is about OpenID Provider Discovery. OpenID Providers have metadata describing their configuration. -The endpoint is usually located at:: - - /.well-known/openid-configuration - -The metadata is formatted in JSON. Here is an example of how it looks like: - -.. code-block:: http - - HTTP/1.1 200 OK - Content-Type: application/json - - { - "issuer": - "https://server.example.com", - "authorization_endpoint": - "https://server.example.com/connect/authorize", - "token_endpoint": - "https://server.example.com/connect/token", - "token_endpoint_auth_methods_supported": - ["client_secret_basic", "private_key_jwt"], - "token_endpoint_auth_signing_alg_values_supported": - ["RS256", "ES256"], - "userinfo_endpoint": - "https://server.example.com/connect/userinfo", - "check_session_iframe": - "https://server.example.com/connect/check_session", - "end_session_endpoint": - "https://server.example.com/connect/end_session", - "jwks_uri": - "https://server.example.com/jwks.json", - "registration_endpoint": - "https://server.example.com/connect/register", - "scopes_supported": - ["openid", "profile", "email", "address", - "phone", "offline_access"], - "response_types_supported": - ["code", "code id_token", "id_token", "token id_token"], - "acr_values_supported": - ["urn:mace:incommon:iap:silver", - "urn:mace:incommon:iap:bronze"], - "subject_types_supported": - ["public", "pairwise"], - "userinfo_signing_alg_values_supported": - ["RS256", "ES256", "HS256"], - "userinfo_encryption_alg_values_supported": - ["RSA1_5", "A128KW"], - "userinfo_encryption_enc_values_supported": - ["A128CBC-HS256", "A128GCM"], - "id_token_signing_alg_values_supported": - ["RS256", "ES256", "HS256"], - "id_token_encryption_alg_values_supported": - ["RSA1_5", "A128KW"], - "id_token_encryption_enc_values_supported": - ["A128CBC-HS256", "A128GCM"], - "request_object_signing_alg_values_supported": - ["none", "RS256", "ES256"], - "display_values_supported": - ["page", "popup"], - "claim_types_supported": - ["normal", "distributed"], - "claims_supported": - ["sub", "iss", "auth_time", "acr", - "name", "given_name", "family_name", "nickname", - "profile", "picture", "website", - "email", "email_verified", "locale", "zoneinfo", - "http://example.info/claims/groups"], - "claims_parameter_supported": - true, - "service_documentation": - "http://server.example.com/connect/service_documentation.html", - "ui_locales_supported": - ["en-US", "en-GB", "en-CA", "fr-FR", "fr-CA"] - } diff --git a/docs/oauth/oidc/index.rst b/docs/oauth/oidc/index.rst deleted file mode 100644 index 7df8f554c..000000000 --- a/docs/oauth/oidc/index.rst +++ /dev/null @@ -1,9 +0,0 @@ -OpenID Connect -============== - -.. toctree:: - :maxdepth: 2 - - intro - core - discovery diff --git a/docs/oauth/oidc/intro.rst b/docs/oauth/oidc/intro.rst deleted file mode 100644 index 9215b4317..000000000 --- a/docs/oauth/oidc/intro.rst +++ /dev/null @@ -1,6 +0,0 @@ -Introduce OpenID Connect -======================== - -OpenID Connect is an identity layer on top of the OAuth 2.0 framework. - -(TBD) diff --git a/docs/oauth1/client/http/api.rst b/docs/oauth1/client/http/api.rst new file mode 100644 index 000000000..ffd850d93 --- /dev/null +++ b/docs/oauth1/client/http/api.rst @@ -0,0 +1,43 @@ +Reference +========= + +.. meta:: + :description: API references on Authlib OAuth 1.0 HTTP session clients. + +Requests OAuth 1.0 +------------------ + +.. module:: authlib.integrations.requests_client + +.. autoclass:: OAuth1Session + :members: + create_authorization_url, + fetch_request_token, + fetch_access_token, + parse_authorization_response + +.. autoclass:: OAuth1Auth + :members: + + +HTTPX OAuth 1.0 +--------------- + +.. module:: authlib.integrations.httpx_client + +.. autoclass:: OAuth1Auth + :members: + +.. autoclass:: OAuth1Client + :members: + create_authorization_url, + fetch_request_token, + fetch_access_token, + parse_authorization_response + +.. autoclass:: AsyncOAuth1Client + :members: + create_authorization_url, + fetch_request_token, + fetch_access_token, + parse_authorization_response diff --git a/docs/oauth1/client/http/httpx.rst b/docs/oauth1/client/http/httpx.rst new file mode 100644 index 000000000..55f887203 --- /dev/null +++ b/docs/oauth1/client/http/httpx.rst @@ -0,0 +1,53 @@ +.. _httpx_oauth1_client: + +OAuth 1.0 for HTTPX +=================== + +.. meta:: + :description: An OAuth 1.0 Client implementation for HTTPX, + powered by Authlib. + +.. module:: authlib.integrations.httpx_client + :noindex: + +HTTPX is a next-generation HTTP client for Python. Authlib enables OAuth 1.0 +for HTTPX with: + +* :class:`OAuth1Client` +* :class:`AsyncOAuth1Client` + +.. note:: HTTPX is still in its "alpha" stage, use it with caution. + +HTTPX OAuth 1.0 +--------------- + +There are three steps in OAuth 1 to obtain an access token: + +1. fetch a temporary credential +2. visit the authorization page +3. exchange access token with the temporary credential + +It shares a common API design with :ref:`requests_client`. + +Read the common guide of :ref:`OAuth 1 Session ` to understand the whole OAuth +1.0 flow. + + +Async OAuth 1.0 +--------------- + +The async version of :class:`AsyncOAuth1Client` works the same as +:ref:`OAuth 1 Session `, except that we need to add ``await`` when +required:: + + # fetching request token + request_token = await client.fetch_request_token(request_token_url) + + # fetching access token + access_token = await client.fetch_access_token(access_token_url) + + # normal requests + await client.get(...) + await client.post(...) + await client.put(...) + await client.delete(...) diff --git a/docs/client/oauth1.rst b/docs/oauth1/client/http/index.rst similarity index 95% rename from docs/client/oauth1.rst rename to docs/oauth1/client/http/index.rst index 768c599ec..12492f899 100644 --- a/docs/client/oauth1.rst +++ b/docs/oauth1/client/http/index.rst @@ -1,7 +1,7 @@ .. _oauth_1_session: -OAuth 1 Session -=============== +HTTP Clients +============ .. meta:: :description: An OAuth 1.0 protocol Client implementation for Python @@ -21,12 +21,15 @@ Authlib provides three implementations of OAuth 1.0 client: :class:`requests_client.OAuth1Session` and :class:`httpx_client.AsyncOAuth1Client` shares the same API. -There are also frameworks integrations of :ref:`flask_client`, :ref:`django_client` -and :ref:`starlette_client`. If you are using these frameworks, you may have interests -in their own documentation. - If you are not familiar with OAuth 1.0, it is better to read :ref:`intro_oauth1` now. +.. toctree:: + :maxdepth: 1 + + requests + httpx + api + Initialize OAuth 1.0 Client --------------------------- @@ -78,7 +81,7 @@ The second step is to generate the authorization URL:: 'https://api.twitter.com/oauth/authenticate?oauth_token=gA..H' Actually, the second parameter ``request_token`` can be omitted, since session -is re-used:: +is reused:: >>> client.create_authorization_url(authenticate_url) @@ -141,7 +144,7 @@ session:: Access Protected Resources -------------------------- -Now you can access the protected resources. If you re-use the session, you +Now you can access the protected resources. If you reuse the session, you don't need to do anything:: >>> account_url = 'https://api.twitter.com/1.1/account/verify_credentials.json' @@ -181,7 +184,7 @@ Create an instance of OAuth1Auth with an access token:: auth = OAuth1Auth( client_id='..', - client_secret=client_secret='..', + client_secret='..', token='oauth_token value', token_secret='oauth_token_secret value', ... @@ -200,3 +203,4 @@ If using ``httpx``, pass this ``auth`` to access protected resources:: url = 'https://api.twitter.com/1.1/account/verify_credentials.json' resp = await httpx.get(url, auth=auth) + diff --git a/docs/oauth1/client/http/requests.rst b/docs/oauth1/client/http/requests.rst new file mode 100644 index 000000000..c2a9e0719 --- /dev/null +++ b/docs/oauth1/client/http/requests.rst @@ -0,0 +1,36 @@ +.. _requests_oauth1_client: + +OAuth 1.0 for Requests +====================== + +.. meta:: + :description: An OAuth 1.0 Client implementation for Python requests, + powered by Authlib. + +.. module:: authlib.integrations.requests_client + :noindex: + +Requests is a very popular HTTP library for Python. Authlib enables OAuth 1.0 +for Requests with its :class:`OAuth1Session` and :class:`OAuth1Auth`. + + +OAuth1Session +~~~~~~~~~~~~~ + +The requests integration follows our common guide of :ref:`OAuth 1 Session `. +Follow the documentation in :ref:`OAuth 1 Session ` instead. + +OAuth1Auth +~~~~~~~~~~ + +It is also possible to use :class:`OAuth1Auth` directly with requests. +After we obtained access token from an OAuth 1.0 provider, we can construct +an ``auth`` instance for requests:: + + auth = OAuth1Auth( + client_id='YOUR-CLIENT-ID', + client_secret='YOUR-CLIENT-SECRET', + token='oauth_token', + token_secret='oauth_token_secret', + ) + requests.get(url, auth=auth) diff --git a/docs/oauth1/client/index.rst b/docs/oauth1/client/index.rst new file mode 100644 index 000000000..13b264d55 --- /dev/null +++ b/docs/oauth1/client/index.rst @@ -0,0 +1,40 @@ +.. _oauth1_client: + +Client +====== + +.. meta:: + :description: Python OAuth 1.0 Client implementations with requests, HTTPX, + Flask, Django and Starlette, powered by Authlib. + +Authlib provides OAuth 1.0 client implementations for two distinct use cases: + +**HTTP Clients** — your Python code fetches tokens and calls APIs directly. +Suitable for scripts, CLIs, service-to-service communication:: + + from authlib.integrations.requests_client import OAuth1Session + + client = OAuth1Session(client_id, client_secret) + request_token = client.fetch_request_token(request_token_url) + # ... redirect user, then: + token = client.fetch_access_token(access_token_url) + resp = client.get('https://api.example.com/data') + +**Web Clients** — your web application delegates authentication to an OAuth 1.0 +provider. Works with any OAuth 1.0 provider (Twitter, or your own). Integrations +for Flask, Django, Starlette and FastAPI:: + + from authlib.integrations.flask_client import OAuth + + oauth = OAuth(app) + twitter = oauth.register('twitter', {...}) + + @app.route('/login') + def login(): + return twitter.authorize_redirect(url_for('authorize', _external=True)) + +.. toctree:: + :maxdepth: 2 + + http/index + web/index diff --git a/docs/oauth1/client/web/api.rst b/docs/oauth1/client/web/api.rst new file mode 100644 index 000000000..82da1682b --- /dev/null +++ b/docs/oauth1/client/web/api.rst @@ -0,0 +1,38 @@ +Reference +========= + +.. meta:: + :description: API references on Authlib OAuth 1.0 web framework client integrations. + +Flask Registry and RemoteApp +----------------------------- + +.. module:: authlib.integrations.flask_client + +.. autoclass:: OAuth + :members: + init_app, + register, + create_client + + +Django Registry and RemoteApp +------------------------------ + +.. module:: authlib.integrations.django_client + +.. autoclass:: OAuth + :members: + register, + create_client + + +Starlette Registry and RemoteApp +--------------------------------- + +.. module:: authlib.integrations.starlette_client + +.. autoclass:: OAuth + :members: + register, + create_client diff --git a/docs/oauth1/client/web/django.rst b/docs/oauth1/client/web/django.rst new file mode 100644 index 000000000..90c48821e --- /dev/null +++ b/docs/oauth1/client/web/django.rst @@ -0,0 +1,91 @@ +.. _django_oauth1_client: + +Django Integration +================== + +.. meta:: + :description: The built-in Django integrations for OAuth 1.0 + clients, powered by Authlib. + +.. module:: authlib.integrations.django_client + :noindex: + +Looking for OAuth 1.0 provider? + +- :ref:`django_oauth1_server` + +The Django client can handle OAuth 1 services. Authlib has a shared API design +among framework integrations. Get started with :ref:`frameworks_oauth1_clients`. + +Create a registry with :class:`OAuth` object:: + + from authlib.integrations.django_client import OAuth + + oauth = OAuth() + +The common use case for OAuth is authentication, e.g. let your users log in +with Twitter. + +.. important:: + + Please read :ref:`frameworks_oauth1_clients` at first. Authlib has a shared + API design among framework integrations, learn them from + :ref:`frameworks_oauth1_clients`. + + +Configuration +------------- + +Authlib Django OAuth registry can load the configuration from your Django +application settings automatically. Every key value pair can be omitted. +They can be configured from your Django settings:: + + AUTHLIB_OAUTH_CLIENTS = { + 'twitter': { + 'client_id': 'Twitter Consumer Key', + 'client_secret': 'Twitter Consumer Secret', + 'request_token_url': 'https://api.twitter.com/oauth/request_token', + 'request_token_params': None, + 'access_token_url': 'https://api.twitter.com/oauth/access_token', + 'access_token_params': None, + 'authorize_url': 'https://api.twitter.com/oauth/authenticate', + 'api_base_url': 'https://api.twitter.com/1.1/', + 'client_kwargs': None + } + } + +There are differences between OAuth 1.0 and OAuth 2.0, please check the +parameters in ``.register`` in :ref:`frameworks_oauth1_clients`. + +Saving Temporary Credential +--------------------------- + +In OAuth 1.0, we need to use a temporary credential to exchange access token, +this temporary credential was created before redirecting to the provider (Twitter), +we need to save this temporary credential somewhere in order to use it later. + +In OAuth 1, Django client will save the request token in sessions. In this +case, you just need to configure Session Middleware in Django:: + + MIDDLEWARE = [ + 'django.contrib.sessions.middleware.SessionMiddleware' + ] + +Follow the official Django documentation to set a proper session. Either a +database backend or a cache backend would work well. + +.. warning:: + + Be aware, using secure cookie as session backend will expose your request + token. + +Routes for Authorization +------------------------ + +Just like the example in :ref:`frameworks_oauth1_clients`, everything is the same. +But there is a hint to create ``redirect_uri`` with ``request`` in Django:: + + def login(request): + # build a full authorize callback uri + redirect_uri = request.build_absolute_uri('/authorize') + return oauth.twitter.authorize_redirect(request, redirect_uri) diff --git a/docs/oauth1/client/web/fastapi.rst b/docs/oauth1/client/web/fastapi.rst new file mode 100644 index 000000000..d5c2b74c0 --- /dev/null +++ b/docs/oauth1/client/web/fastapi.rst @@ -0,0 +1,51 @@ +.. _fastapi_oauth1_client: + +FastAPI Integration +=================== + +.. meta:: + :description: Use Authlib built-in Starlette integrations to build + OAuth 1.0 clients for FastAPI. + +.. module:: authlib.integrations.starlette_client + :noindex: + +FastAPI_ is a modern, fast (high-performance), web framework for building +APIs with Python 3.6+ based on standard Python type hints. It is built on +top of **Starlette**, that means most of the code looks similar with +Starlette code. You should first read documentation of: + +1. :ref:`frameworks_oauth1_clients` +2. :ref:`starlette_oauth1_client` + +Here is how you would create a FastAPI application:: + + from fastapi import FastAPI + from starlette.middleware.sessions import SessionMiddleware + + app = FastAPI() + # we need this to save temporary credential in session + app.add_middleware(SessionMiddleware, secret_key="some-random-string") + +Since Authlib starlette requires using ``request`` instance, we need to +expose that ``request`` to Authlib. According to the documentation on +`Using the Request Directly `_:: + + from starlette.requests import Request + + @app.get("/login/twitter") + async def login_via_twitter(request: Request): + redirect_uri = request.url_for('auth_via_twitter') + return await oauth.twitter.authorize_redirect(request, redirect_uri) + + @app.get("/auth/twitter") + async def auth_via_twitter(request: Request): + token = await oauth.twitter.authorize_access_token(request) + # do something with the token + return dict(token) + +.. _FastAPI: https://fastapi.tiangolo.com/ + +We have a blog post about how to create Twitter login in FastAPI: + +https://blog.authlib.org/2020/fastapi-twitter-login diff --git a/docs/oauth1/client/web/flask.rst b/docs/oauth1/client/web/flask.rst new file mode 100644 index 000000000..12efda483 --- /dev/null +++ b/docs/oauth1/client/web/flask.rst @@ -0,0 +1,182 @@ +.. _flask_oauth1_client: + +Flask Integration +================= + +.. meta:: + :description: The built-in Flask integrations for OAuth 1.0 + clients, powered by Authlib. + +.. module:: authlib.integrations.flask_client + :noindex: + +This documentation covers OAuth 1.0 Client support for Flask. Looking for +OAuth 1.0 provider? + +- :ref:`flask_oauth1_server` + +Flask OAuth client can handle OAuth 1 services. It shares a similar API with +Flask-OAuthlib, you can transfer your code from Flask-OAuthlib to Authlib with +ease. + +Create a registry with :class:`OAuth` object:: + + from authlib.integrations.flask_client import OAuth + + oauth = OAuth(app) + +You can also initialize it later with :meth:`~OAuth.init_app` method:: + + oauth = OAuth() + oauth.init_app(app) + +.. important:: + + Please read :ref:`frameworks_oauth1_clients` at first. Authlib has a shared + API design among framework integrations, learn them from + :ref:`frameworks_oauth1_clients`. + +Configuration +------------- + +Authlib Flask OAuth registry can load the configuration from Flask ``app.config`` +automatically. Every key-value pair in ``.register`` can be omitted. They can be +configured in your Flask App configuration. Config keys are formatted as +``{name}_{key}`` in uppercase, e.g. + +========================== ================================ +TWITTER_CLIENT_ID Twitter Consumer Key +TWITTER_CLIENT_SECRET Twitter Consumer Secret +TWITTER_REQUEST_TOKEN_URL URL to fetch OAuth request token +========================== ================================ + +If you register your remote app as ``oauth.register('example', ...)``, the +config keys would look like: + +========================== =============================== +EXAMPLE_CLIENT_ID OAuth Consumer Key +EXAMPLE_CLIENT_SECRET OAuth Consumer Secret +EXAMPLE_REQUEST_TOKEN_URL URL to fetch OAuth request token +========================== =============================== + +Here is a full list of the configuration keys: + +- ``{name}_CLIENT_ID``: Client key of OAuth 1 +- ``{name}_CLIENT_SECRET``: Client secret of OAuth 1 +- ``{name}_REQUEST_TOKEN_URL``: Request Token endpoint for OAuth 1 +- ``{name}_REQUEST_TOKEN_PARAMS``: Extra parameters for Request Token endpoint +- ``{name}_ACCESS_TOKEN_URL``: Access Token endpoint for OAuth 1 +- ``{name}_ACCESS_TOKEN_PARAMS``: Extra parameters for Access Token endpoint +- ``{name}_AUTHORIZE_URL``: Endpoint for user authorization of OAuth 1 +- ``{name}_AUTHORIZE_PARAMS``: Extra parameters for Authorization Endpoint. +- ``{name}_API_BASE_URL``: A base URL endpoint to make requests simple +- ``{name}_CLIENT_KWARGS``: Extra keyword arguments for OAuth1Session + + +Using Cache for Temporary Credential +------------------------------------- + +By default, the Flask OAuth registry will use Flask session to store OAuth 1.0 +temporary credential (request token). However, in this way, there are chances +your temporary credential will be exposed. + +Our ``OAuth`` registry provides a simple way to store temporary credentials in a +cache system. When initializing ``OAuth``, you can pass an ``cache`` instance:: + + oauth = OAuth(app, cache=cache) + + # or initialize lazily + oauth = OAuth() + oauth.init_app(app, cache=cache) + +A ``cache`` instance MUST have methods: + +- ``.delete(key)`` +- ``.get(key)`` +- ``.set(key, value, expires=None)`` + +An example of a ``cache`` instance can be: + +.. code-block:: python + + from flask import Flask + + class OAuthCache: + + def __init__(self, app: Flask) -> None: + self.app = app + + def delete(self, key: str) -> None: + pass + + def get(self, key: str) -> str | None: + pass + + def set(self, key: str, value: str, expires: int | None = None) -> None: + pass + + +Routes for Authorization +------------------------ + +Unlike the examples in :ref:`frameworks_oauth1_clients`, Flask does not pass a +``request`` into routes. In this case, the routes for authorization should look +like:: + + from flask import url_for, redirect + + @app.route('/login') + def login(): + redirect_uri = url_for('authorize', _external=True) + return oauth.twitter.authorize_redirect(redirect_uri) + + @app.route('/authorize') + def authorize(): + token = oauth.twitter.authorize_access_token() + resp = oauth.twitter.get('account/verify_credentials.json') + resp.raise_for_status() + profile = resp.json() + # do something with the token and profile + return redirect('/') + +Accessing OAuth Resources +------------------------- + +There is no ``request`` in accessing OAuth resources either. Just like above, +we don't need to pass the ``request`` parameter, everything is handled by Authlib +automatically:: + + from flask import render_template + + @app.route('/twitter') + def show_twitter_timeline(): + resp = oauth.twitter.get('statuses/user_timeline.json') + resp.raise_for_status() + tweets = resp.json() + return render_template('twitter.html', tweets=tweets) + +In this case, our ``fetch_token`` could look like:: + + from your_project import current_user + + def fetch_token(name): + token = OAuth1Token.find( + name=name, + user=current_user, + ) + return token.to_token() + + # initialize the OAuth registry with this fetch_token function + oauth = OAuth(fetch_token=fetch_token) + +You don't have to pass ``token``, you don't have to pass ``request``. That +is the fantasy of Flask. + +Examples +--------- + +Here are some example projects for you to learn Flask OAuth 1.0 client integrations: + +1. `Flask Twitter Login`_. + +.. _`Flask Twitter Login`: https://github.com/authlib/demo-oauth-client/tree/master/flask-twitter-tool diff --git a/docs/oauth1/client/web/index.rst b/docs/oauth1/client/web/index.rst new file mode 100644 index 000000000..aff214828 --- /dev/null +++ b/docs/oauth1/client/web/index.rst @@ -0,0 +1,247 @@ +.. _frameworks_oauth1_clients: + +Web Clients +=========== + +.. module:: authlib.integrations + :noindex: + +This documentation covers OAuth 1.0 integrations for Python Web Frameworks like: + +* Django: The web framework for perfectionists with deadlines +* Flask: The Python micro framework for building web applications +* Starlette: The little ASGI framework that shines + + +Authlib shares a common API design among these web frameworks. Instead +of introducing them one by one, this documentation contains the common +usage for them all. + +We start with creating a registry with the ``OAuth`` class:: + + # for Flask framework + from authlib.integrations.flask_client import OAuth + + # for Django framework + from authlib.integrations.django_client import OAuth + + # for Starlette framework + from authlib.integrations.starlette_client import OAuth + + oauth = OAuth() + +There are little differences among each framework, you can read their +documentation later: + +1. :class:`flask_client.OAuth` for :ref:`flask_oauth1_client` +2. :class:`django_client.OAuth` for :ref:`django_oauth1_client` +3. :class:`starlette_client.OAuth` for :ref:`starlette_oauth1_client` + +The common use case for OAuth is authentication, e.g. let your users log in +with Twitter. + +Log In with OAuth 1.0 +--------------------- + +For instance, Twitter is an OAuth 1.0 service, you want your users to log in +your website with Twitter. + +The first step is register a remote application on the ``OAuth`` registry via +``oauth.register`` method:: + + oauth.register( + name='twitter', + client_id='{{ your-twitter-consumer-key }}', + client_secret='{{ your-twitter-consumer-secret }}', + request_token_url='https://api.twitter.com/oauth/request_token', + request_token_params=None, + access_token_url='https://api.twitter.com/oauth/access_token', + access_token_params=None, + authorize_url='https://api.twitter.com/oauth/authenticate', + authorize_params=None, + api_base_url='https://api.twitter.com/1.1/', + client_kwargs=None, + ) + +The first parameter in ``register`` method is the **name** of the remote +application. You can access the remote application with:: + + twitter = oauth.create_client('twitter') + # or simply with + twitter = oauth.twitter + +The configuration of those parameters can be loaded from the framework +configuration. Each framework has its own config system, read the framework +specified documentation later. + +For instance, if ``client_id`` and ``client_secret`` can be loaded via +configuration, we can simply register the remote app with:: + + oauth.register( + name='twitter', + request_token_url='https://api.twitter.com/oauth/request_token', + access_token_url='https://api.twitter.com/oauth/access_token', + authorize_url='https://api.twitter.com/oauth/authenticate', + api_base_url='https://api.twitter.com/1.1/', + ) + +The ``client_kwargs`` is a dict configuration to pass extra parameters to +:ref:`OAuth 1 Session `. If you are using ``RSA-SHA1`` signature method:: + + client_kwargs = { + 'signature_method': 'RSA-SHA1', + 'signature_type': 'HEADER', + 'rsa_key': 'Your-RSA-Key' + } + + +Saving Temporary Credential +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Usually, the framework integration has already implemented this part through +the framework session system. All you need to do is enable session for the +chosen framework. + +Routes for Authorization +~~~~~~~~~~~~~~~~~~~~~~~~ + +After configuring the ``OAuth`` registry and the remote application, the +rest steps are much simpler. The only required parts are routes: + +1. redirect to 3rd party provider (Twitter) for authentication +2. redirect back to your website to fetch access token and profile + +Here is the example for Twitter login:: + + def login(request): + twitter = oauth.create_client('twitter') + redirect_uri = 'https://example.com/authorize' + return twitter.authorize_redirect(request, redirect_uri) + + def authorize(request): + twitter = oauth.create_client('twitter') + token = twitter.authorize_access_token(request) + resp = twitter.get('account/verify_credentials.json') + resp.raise_for_status() + profile = resp.json() + # do something with the token and profile + return '...' + +After user confirmed on Twitter authorization page, it will redirect +back to your website ``authorize`` page. In this route, you can get your +user's twitter profile information, you can store the user information +in your database, mark your user as logged in and etc. + + +Accessing OAuth Resources +------------------------- + +.. note:: + + If your application ONLY needs login via 3rd party services like + Twitter to login, you DON'T need to create the token database. + +There are also chances that you need to access your user's 3rd party +OAuth provider resources. For instance, you want to display the logged +in user's twitter time line. You will use **access token** to fetch +the resources:: + + def get_twitter_tweets(request): + token = OAuth1Token.find( + name='twitter', + user=request.user + ) + # API URL: https://api.twitter.com/1.1/statuses/user_timeline.json + resp = oauth.twitter.get('statuses/user_timeline.json', token=token.to_token()) + resp.raise_for_status() + return resp.json() + +In this case, we need a place to store the access token in order to use +it later. Usually we will save the token into database. In the previous +**Routes for Authorization** ``authorize`` part, we can save the token into +database. + + +Design Database +~~~~~~~~~~~~~~~ + +Here are some hints on how to design the OAuth 1.0 token database:: + + class OAuth1Token(Model): + name = String(length=40) + oauth_token = String(length=200) + oauth_token_secret = String(length=200) + user = ForeignKey(User) + + def to_token(self): + return dict( + oauth_token=self.access_token, + oauth_token_secret=self.alt_token, + ) + + +And then we can save user's access token into database when user was redirected +back to our ``authorize`` page. + + +Fetch User OAuth Token +~~~~~~~~~~~~~~~~~~~~~~ + +You can always pass a ``token`` parameter to the remote application request +methods, like:: + + token = OAuth1Token.find(name='twitter', user=request.user) + oauth.twitter.get(url, token=token) + oauth.twitter.post(url, token=token) + oauth.twitter.put(url, token=token) + oauth.twitter.delete(url, token=token) + +However, it is not a good practice to query the token database in every request +function. Authlib provides a way to fetch current user's token automatically for +you, just ``register`` with ``fetch_token`` function:: + + def fetch_twitter_token(request): + token = OAuth1Token.find( + name='twitter', + user=request.user + ) + return token.to_token() + + # we can registry this ``fetch_token`` with oauth.register + oauth.register( + 'twitter', + # ... + fetch_token=fetch_twitter_token, + ) + +There is also a shared way to fetch token:: + + def fetch_token(name, request): + token = OAuth1Token.find( + name=name, + user=request.user + ) + return token.to_token() + + # initialize OAuth registry with this fetch_token function + oauth = OAuth(fetch_token=fetch_token) + +Now, developers don't have to pass a ``token`` in the HTTP requests, +instead, they can pass the ``request``:: + + def get_twitter_tweets(request): + resp = oauth.twitter.get('statuses/user_timeline.json', request=request) + resp.raise_for_status() + return resp.json() + + +.. note:: Flask is different, you don't need to pass the ``request`` either. + +.. toctree:: + :maxdepth: 1 + + flask + django + starlette + fastapi + api diff --git a/docs/oauth1/client/web/starlette.rst b/docs/oauth1/client/web/starlette.rst new file mode 100644 index 000000000..ec3c3485b --- /dev/null +++ b/docs/oauth1/client/web/starlette.rst @@ -0,0 +1,82 @@ +.. _starlette_oauth1_client: + +Starlette Integration +===================== + +.. meta:: + :description: The built-in Starlette integrations for OAuth 1.0 + clients, powered by Authlib. + +.. module:: authlib.integrations.starlette_client + :noindex: + +Starlette_ is a lightweight ASGI framework/toolkit, which is ideal for +building high performance asyncio services. + +.. _Starlette: https://www.starlette.io/ + +This documentation covers OAuth 1.0 Client support for Starlette. Because all +the frameworks integrations share the same API, it is best to: + +Read :ref:`frameworks_oauth1_clients` at first. + +The difference between Starlette and Flask/Django integrations is Starlette +is **async**. We will use ``await`` for the functions we need to call. But +first, let's create an :class:`OAuth` instance:: + + from authlib.integrations.starlette_client import OAuth + + oauth = OAuth() + +Unlike Flask and Django, Starlette OAuth registry uses HTTPX +:class:`~authlib.integrations.httpx_client.AsyncOAuth1Client` as the OAuth 1.0 +backend. + + +Enable Session for OAuth 1.0 +----------------------------- + +With OAuth 1.0, we need to use a temporary credential to exchange for an access +token. This temporary credential is created before redirecting to the provider +(Twitter), and needs to be saved somewhere in order to use it later. + +With OAuth 1, the Starlette client will save the request token in sessions. To +enable this, we need to add the ``SessionMiddleware`` middleware to the +application, which requires the installation of the ``itsdangerous`` package:: + + from starlette.applications import Starlette + from starlette.middleware.sessions import SessionMiddleware + + app = Starlette() + app.add_middleware(SessionMiddleware, secret_key="some-random-string") + +However, using the ``SessionMiddleware`` will store the temporary credential as +a secure cookie which will expose your request token to the client. + +Routes for Authorization +------------------------ + +Just like the examples in :ref:`frameworks_oauth1_clients`, but Starlette is +**async**, the routes for authorization should look like:: + + @app.route('/login/twitter') + async def login_via_twitter(request): + twitter = oauth.create_client('twitter') + redirect_uri = request.url_for('authorize_twitter') + return await twitter.authorize_redirect(request, redirect_uri) + + @app.route('/auth/twitter') + async def authorize_twitter(request): + twitter = oauth.create_client('twitter') + token = await twitter.authorize_access_token(request) + resp = await twitter.get('account/verify_credentials.json') + profile = resp.json() + # do something with the token and profile + return '...' + +Examples +-------- + +We have Starlette demos at https://github.com/authlib/demo-oauth-client + +1. OAuth 1.0: `Starlette Twitter login `_ diff --git a/docs/oauth/1/intro.rst b/docs/oauth1/concepts.rst similarity index 99% rename from docs/oauth/1/intro.rst rename to docs/oauth1/concepts.rst index bf4e12da6..b8855b8ad 100644 --- a/docs/oauth/1/intro.rst +++ b/docs/oauth1/concepts.rst @@ -5,8 +5,8 @@ .. _intro_oauth1: -Introduce OAuth 1.0 -=================== +Concepts +======== OAuth 1.0 is the standardization and combined wisdom of many well established industry protocols at its creation time. It was first introduced as Twitter's open protocol. It is similar to other protocols diff --git a/docs/oauth1/index.rst b/docs/oauth1/index.rst new file mode 100644 index 000000000..d6a375f42 --- /dev/null +++ b/docs/oauth1/index.rst @@ -0,0 +1,10 @@ +OAuth 1.0 +========= + +.. toctree:: + :maxdepth: 2 + + concepts + client/index + provider/index + specs/index diff --git a/docs/django/1/api.rst b/docs/oauth1/provider/django/api.rst similarity index 100% rename from docs/django/1/api.rst rename to docs/oauth1/provider/django/api.rst diff --git a/docs/django/1/authorization-server.rst b/docs/oauth1/provider/django/authorization-server.rst similarity index 100% rename from docs/django/1/authorization-server.rst rename to docs/oauth1/provider/django/authorization-server.rst diff --git a/docs/django/1/index.rst b/docs/oauth1/provider/django/index.rst similarity index 85% rename from docs/django/1/index.rst rename to docs/oauth1/provider/django/index.rst index 2a70170db..80e64b7b6 100644 --- a/docs/django/1/index.rst +++ b/docs/oauth1/provider/django/index.rst @@ -1,7 +1,7 @@ .. _django_oauth1_server: -Django OAuth 1.0 Server -======================= +Django Integration +================== .. meta:: :description: How to create an OAuth 1.0 server in Django with Authlib. @@ -23,7 +23,7 @@ At the very beginning, we need to have some basic understanding of export AUTHLIB_INSECURE_TRANSPORT=true -Looking for Django OAuth 1.0 client? Check out :ref:`django_client`. +Looking for Django OAuth 1.0 client? Check out :ref:`django_oauth1_client`. .. toctree:: :maxdepth: 2 diff --git a/docs/django/1/resource-server.rst b/docs/oauth1/provider/django/resource-server.rst similarity index 97% rename from docs/django/1/resource-server.rst rename to docs/oauth1/provider/django/resource-server.rst index 96340424d..7c0efe26d 100644 --- a/docs/django/1/resource-server.rst +++ b/docs/oauth1/provider/django/resource-server.rst @@ -11,7 +11,7 @@ server. Here is the way to protect your users' resources:: from authlib.integrations.django_oauth1 import ResourceProtector require_oauth = ResourceProtector(Client, TokenCredential) - @require_oauth() + @require_oauth def user_api(request): user = request.oauth1_credential.user return JsonResponse(dict(username=user.username)) diff --git a/docs/oauth1/provider/flask/api.rst b/docs/oauth1/provider/flask/api.rst new file mode 100644 index 000000000..175feaba0 --- /dev/null +++ b/docs/oauth1/provider/flask/api.rst @@ -0,0 +1,19 @@ +Reference +========= + +This part of the documentation covers the interface of Flask OAuth 1.0 +Server. + +.. module:: authlib.integrations.flask_oauth1 + +.. autoclass:: AuthorizationServer + :members: + +.. autoclass:: ResourceProtector + :member-order: bysource + :members: + +.. data:: current_credential + + Routes protected by :class:`ResourceProtector` can access current credential + with this variable. diff --git a/docs/flask/1/authorization-server.rst b/docs/oauth1/provider/flask/authorization-server.rst similarity index 54% rename from docs/flask/1/authorization-server.rst rename to docs/oauth1/provider/flask/authorization-server.rst index ee37bab9d..3537c8a32 100644 --- a/docs/flask/1/authorization-server.rst +++ b/docs/oauth1/provider/flask/authorization-server.rst @@ -6,6 +6,10 @@ authorization, and issuing token credentials. When the resource owner (user) grants the authorization, this server will issue a token credential to the client. +.. versionchanged:: v1.0.0 + We have removed built-in SQLAlchemy integrations. + + Resource Owner -------------- @@ -31,17 +35,30 @@ information: - Client Password, usually called **client_secret** - Client RSA Public Key (if RSA-SHA1 signature method supported) -Authlib has provided a mixin for SQLAlchemy, define the client with this mixin:: +Developers MUST implement the missing methods of ``authlib.oauth1.ClientMixin``, take an +example of Flask-SQAlchemy:: - from authlib.integrations.sqla_oauth1 import OAuth1ClientMixin + from authlib.oauth1 import ClientMixin - class Client(db.Model, OAuth1ClientMixin): + class Client(ClientMixin, db.Model): id = db.Column(db.Integer, primary_key=True) + client_id = db.Column(db.String(48), index=True) + client_secret = db.Column(db.String(120), nullable=False) + default_redirect_uri = db.Column(db.Text, nullable=False, default='') user_id = db.Column( db.Integer, db.ForeignKey('user.id', ondelete='CASCADE') ) user = db.relationship('User') + def get_default_redirect_uri(self): + return self.default_redirect_uri + + def get_client_secret(self): + return self.client_secret + + def get_rsa_public_key(self): + return None + A client is registered by a user (developer) on your website. Get a deep inside with :class:`~authlib.oauth1.rfc5849.ClientMixin` API reference. @@ -50,7 +67,7 @@ Temporary Credentials A temporary credential is used to exchange a token credential. It is also known as "request token and secret". Since it is temporary, it is better to -save them into cache instead of database. A cache instance should has these +save them into cache instead of database. A cache instance should have these methods: - ``.get(key)`` @@ -58,19 +75,37 @@ methods: - ``.delete(key)`` A cache can be a memcache, redis or something else. If cache is not available, -there is also a SQLAlchemy mixin:: +developers can also implement it with database. For example, using SQLAlchemy:: - from authlib.integrations.sqla_oauth1 import OAuth1TemporaryCredentialMixin + from authlib.oauth1 import TemporaryCredentialMixin - class TemporaryCredential(db.Model, OAuth1TemporaryCredentialMixin): + class TemporaryCredential(TemporaryCredentialMixin, db.Model): id = db.Column(db.Integer, primary_key=True) user_id = db.Column( db.Integer, db.ForeignKey('user.id', ondelete='CASCADE') ) user = db.relationship('User') + client_id = db.Column(db.String(48), index=True) + oauth_token = db.Column(db.String(84), unique=True, index=True) + oauth_token_secret = db.Column(db.String(84)) + oauth_verifier = db.Column(db.String(84)) + oauth_callback = db.Column(db.Text, default='') + + def get_client_id(self): + return self.client_id + + def get_redirect_uri(self): + return self.oauth_callback + + def check_verifier(self, verifier): + return self.oauth_verifier == verifier + + def get_oauth_token(self): + return self.oauth_token + + def get_oauth_token_secret(self): + return self.oauth_token_secret -To make a Temporary Credentials model yourself, get more information with -:class:`~authlib.oauth1.rfc5849.ClientMixin` API reference. Token Credentials ----------------- @@ -79,23 +114,27 @@ A token credential is used to access resource owners' resources. Unlike OAuth 2, the token credential will not expire in OAuth 1. This token credentials are supposed to be saved into a persist database rather than a cache. -Here is a SQLAlchemy mixin for easy integration:: +Developers MUST implement :class:`~authlib.oauth1.rfc5849.TokenCredentialMixin` +missing methods. Here is an example of SQLAlchemy integration:: - from authlib.integrations.sqla_oauth1 import OAuth1TokenCredentialMixin + from authlib.oauth1 import TokenCredentialMixin - class TokenCredential(db.Model, OAuth1TokenCredentialMixin): + class TokenCredential(TokenCredentialMixin, db.Model): id = db.Column(db.Integer, primary_key=True) user_id = db.Column( db.Integer, db.ForeignKey('user.id', ondelete='CASCADE') ) user = db.relationship('User') + client_id = db.Column(db.String(48), index=True) + oauth_token = db.Column(db.String(84), unique=True, index=True) + oauth_token_secret = db.Column(db.String(84)) + + def get_oauth_token(self): + return self.oauth_token - def set_user_id(self, user_id): - self.user_id = user_id + def get_oauth_token_secret(self): + return self.oauth_token_secret -If SQLAlchemy is not what you want, read the API reference of -:class:`~authlib.oauth1.rfc5849.TokenCredentialMixin` and implement the missing -methods. Timestamp and Nonce ------------------- @@ -104,12 +143,21 @@ The nonce value MUST be unique across all requests with the same timestamp, client credentials, and token combinations. Authlib Flask integration has a built-in validation with cache. -If cache is not available, there is also a SQLAlchemy mixin:: +If cache is not available, developers can use a database, here is an example of +using SQLAlchemy:: - from authlib.integrations.sqla_oauth1 import OAuth1TimestampNonceMixin - - class TimestampNonce(db.Model, OAuth1TimestampNonceMixin) + class TimestampNonce(db.Model): + __table_args__ = ( + db.UniqueConstraint( + 'client_id', 'timestamp', 'nonce', 'oauth_token', + name='unique_nonce' + ), + ) id = db.Column(db.Integer, primary_key=True) + client_id = db.Column(db.String(48), nullable=False) + timestamp = db.Column(db.Integer, nullable=False) + nonce = db.Column(db.String(48), nullable=False) + oauth_token = db.Column(db.String(84)) Define A Server @@ -120,9 +168,10 @@ Authlib provides a ready to use which has built-in tools to handle requests and responses:: from authlib.integrations.flask_oauth1 import AuthorizationServer - from authlib.integrations.sqla_oauth1 import create_query_client_func - query_client = create_query_client_func(db.session, Client) + def query_client(client_id): + return Client.query.filter_by(client_id=client_id).first() + server = AuthorizationServer(app, query_client=query_client) It can also be initialized lazily with init_app:: @@ -176,23 +225,86 @@ can take the advantage with:: register_nonce_hooks, register_temporary_credential_hooks ) - from authlib.integrations.sqla_oauth1 import register_token_credential_hooks register_nonce_hooks(server, cache) register_temporary_credential_hooks(server, cache) - register_token_credential_hooks(server, db.session, TokenCredential) -If cache is not available, here are the helpers for SQLAlchemy:: +If cache is not available, developers MUST register the hooks with the database we +defined above:: - from authlib.integrations.sqla_oauth1 import ( - register_nonce_hooks, - register_temporary_credential_hooks, - register_token_credential_hooks - ) + # check if nonce exists - register_nonce_hooks(server, db.session, TimestampNonce) - register_temporary_credential_hooks(server, db.session, TemporaryCredential) - register_token_credential_hooks(server, db.session, TokenCredential) + def exists_nonce(nonce, timestamp, client_id, oauth_token): + q = TimestampNonce.query.filter_by( + nonce=nonce, + timestamp=timestamp, + client_id=client_id, + ) + if oauth_token: + q = q.filter_by(oauth_token=oauth_token) + rv = q.first() + if rv: + return True + + item = TimestampNonce( + nonce=nonce, + timestamp=timestamp, + client_id=client_id, + oauth_token=oauth_token, + ) + db.session.add(item) + db.session.commit() + return False + server.register_hook('exists_nonce', exists_nonce) + + # hooks for temporary credential + + def create_temporary_credential(token, client_id, redirect_uri): + item = TemporaryCredential( + client_id=client_id, + oauth_token=token['oauth_token'], + oauth_token_secret=token['oauth_token_secret'], + oauth_callback=redirect_uri, + ) + db.session.add(item) + db.session.commit() + return item + + def get_temporary_credential(oauth_token): + return TemporaryCredential.query.filter_by(oauth_token=oauth_token).first() + + def delete_temporary_credential(oauth_token): + q = TemporaryCredential.query.filter_by(oauth_token=oauth_token) + q.delete(synchronize_session=False) + db.session.commit() + + def create_authorization_verifier(credential, grant_user, verifier): + credential.user_id = grant_user.id # assuming your end user model has `.id` + credential.oauth_verifier = verifier + db.session.add(credential) + db.session.commit() + return credential + + server.register_hook('create_temporary_credential', create_temporary_credential) + server.register_hook('get_temporary_credential', get_temporary_credential) + server.register_hook('delete_temporary_credential', delete_temporary_credential) + server.register_hook('create_authorization_verifier', create_authorization_verifier) + +For both cache and database temporary credential, Developers MUST register a +``create_token_credential`` hook:: + + def create_token_credential(token, temporary_credential): + credential = TokenCredential( + oauth_token=token['oauth_token'], + oauth_token_secret=token['oauth_token_secret'], + client_id=temporary_credential.get_client_id() + ) + credential.user_id = temporary_credential.user_id + db.session.add(credential) + db.session.commit() + return credential + + server.register_hook('create_token_credential', create_token_credential) Server Implementation @@ -238,4 +350,3 @@ the token credential:: @app.route('/token', methods=['POST']) def issue_token(): return server.create_token_response() - diff --git a/docs/flask/1/customize.rst b/docs/oauth1/provider/flask/customize.rst similarity index 100% rename from docs/flask/1/customize.rst rename to docs/oauth1/provider/flask/customize.rst diff --git a/docs/flask/1/index.rst b/docs/oauth1/provider/flask/index.rst similarity index 86% rename from docs/flask/1/index.rst rename to docs/oauth1/provider/flask/index.rst index f43bf3068..ca014a365 100644 --- a/docs/flask/1/index.rst +++ b/docs/oauth1/provider/flask/index.rst @@ -1,7 +1,7 @@ .. _flask_oauth1_server: -Flask OAuth 1.0 Server -====================== +Flask Integration +================= .. meta:: :description: How to create an OAuth 1.0 server in Flask with Authlib. @@ -22,7 +22,7 @@ At the very beginning, we need to have some basic understanding of export AUTHLIB_INSECURE_TRANSPORT=true -Looking for Flask OAuth 1.0 client? Check out :ref:`flask_client`. +Looking for Flask OAuth 1.0 client? Check out :ref:`flask_oauth1_client`. .. toctree:: :maxdepth: 2 diff --git a/docs/oauth1/provider/flask/resource-server.rst b/docs/oauth1/provider/flask/resource-server.rst new file mode 100644 index 000000000..f5be05830 --- /dev/null +++ b/docs/oauth1/provider/flask/resource-server.rst @@ -0,0 +1,106 @@ +Resource Servers +================ + +.. versionchanged:: v1.0.0 + We have removed built-in SQLAlchemy integrations. + +Protect users resources, so that only the authorized clients with the +authorized access token can access the given scope resources. + +A resource server can be a different server other than the authorization +server. Here is the way to protect your users' resources:: + + from flask import jsonify + from authlib.integrations.flask_oauth1 import ResourceProtector, current_credential + + # we will define ``query_client``, ``query_token``, and ``exists_nonce`` later. + require_oauth = ResourceProtector( + app, query_client=query_client, + query_token=query_token, + exists_nonce=exists_nonce, + ) + # or initialize it lazily + require_oauth = ResourceProtector() + require_oauth.init_app( + app, + query_client=query_client, + query_token=query_token, + exists_nonce=exists_nonce, + ) + + @app.route('/user') + @require_oauth + def user_profile(): + user = current_credential.user + return jsonify(user) + +The ``current_credential`` is a proxy to the Token model you have defined above. +Since there is a ``user`` relationship on the Token model, we can access this +``user`` with ``current_credential.user``. + +Initialize +---------- + +To initialize ``ResourceProtector``, we need three functions: + +1. query_client +2. query_token +3. exists_nonce + +If using SQLAlchemy, the ``query_client`` could be:: + + def query_client(client_id): + # assuming ``Client`` is the model + return Client.query.filter_by(client_id=client_id).first() + +And ``query_token`` would be:: + + def query_token(client_id, oauth_token): + return TokenCredential.query.filter_by(client_id=client_id, oauth_token=oauth_token).first() + +For ``exists_nonce``, if you are using cache now (as in authorization server), Authlib +has a built-in tool function:: + + from authlib.integrations.flask_oauth1 import create_exists_nonce_func + exists_nonce = create_exists_nonce_func(cache) + +If using database, with SQLAlchemy it would look like:: + + def exists_nonce(nonce, timestamp, client_id, oauth_token): + q = db.session.query(TimestampNonce.nonce).filter_by( + nonce=nonce, + timestamp=timestamp, + client_id=client_id, + ) + if oauth_token: + q = q.filter_by(oauth_token=oauth_token) + rv = q.first() + if rv: + return True + + tn = TimestampNonce( + nonce=nonce, + timestamp=timestamp, + client_id=client_id, + oauth_token=oauth_token, + ) + db.session.add(tn) + db.session.commit() + return False + +MethodView & Flask-Restful +-------------------------- + +You can also use the ``require_oauth`` decorator in ``flask.views.MethodView`` +and ``flask_restful.Resource``:: + + from flask.views import MethodView + + class UserAPI(MethodView): + decorators = [require_oauth] + + + from flask_restful import Resource + + class UserAPI(Resource): + method_decorators = [require_oauth] diff --git a/docs/oauth1/provider/index.rst b/docs/oauth1/provider/index.rst new file mode 100644 index 000000000..32b572d34 --- /dev/null +++ b/docs/oauth1/provider/index.rst @@ -0,0 +1,8 @@ +Provider +======== + +.. toctree:: + :maxdepth: 2 + + flask/index + django/index diff --git a/docs/oauth1/specs/index.rst b/docs/oauth1/specs/index.rst new file mode 100644 index 000000000..312c321ec --- /dev/null +++ b/docs/oauth1/specs/index.rst @@ -0,0 +1,7 @@ +Specifications +============== + +.. toctree:: + :maxdepth: 1 + + rfc5849 diff --git a/docs/specs/rfc5849.rst b/docs/oauth1/specs/rfc5849.rst similarity index 100% rename from docs/specs/rfc5849.rst rename to docs/oauth1/specs/rfc5849.rst diff --git a/docs/django/2/api.rst b/docs/oauth2/authorization-server/django/api.rst similarity index 86% rename from docs/django/2/api.rst rename to docs/oauth2/authorization-server/django/api.rst index 7fddd7cf0..0f37c221d 100644 --- a/docs/django/2/api.rst +++ b/docs/oauth2/authorization-server/django/api.rst @@ -1,5 +1,5 @@ -API References of Django OAuth 2.0 Server -========================================= +Reference +========= This part of the documentation covers the interface of Django OAuth 2.0 Server. @@ -10,7 +10,7 @@ Server. :members: register_grant, register_endpoint, - validate_consent_request, + get_consent_grant, create_authorization_response, create_token_response, create_endpoint_response diff --git a/docs/django/2/authorization-server.rst b/docs/oauth2/authorization-server/django/authorization-server.rst similarity index 89% rename from docs/django/2/authorization-server.rst rename to docs/oauth2/authorization-server/django/authorization-server.rst index 4e7e0fe83..9424d2430 100644 --- a/docs/django/2/authorization-server.rst +++ b/docs/oauth2/authorization-server/django/authorization-server.rst @@ -67,14 +67,14 @@ the missing methods of :class:`~authlib.oauth2.rfc6749.ClientMixin`:: return True return redirect_uri in self.redirect_uris - def has_client_secret(self): - return bool(self.client_secret) - def check_client_secret(self, client_secret): return self.client_secret == client_secret - def check_token_endpoint_auth_method(self, method): - return self.token_endpoint_auth_method == method + def check_endpoint_auth_method(self, method, endpoint): + if endpoint == 'token': + return self.token_endpoint_auth_method == method + # TODO: developers can update this check method + return True def check_response_type(self, response_type): allowed = self.response_type.split() @@ -151,17 +151,22 @@ The ``AuthorizationServer`` has provided built-in methods to handle these endpoi # use ``server.create_authorization_response`` to handle authorization endpoint def authorize(request): + try: + grant = server.get_consent_grant(request, end_user=request.user) + except OAuth2Error as error: + return server.handle_error_response(request, error) + if request.method == 'GET': - grant = server.validate_consent_request(request, end_user=request.user) - context = dict(grant=grant, user=request.user) + scope = grant.client.get_allowed_scope(grant.request.payload.scope) + context = dict(grant=grant, client=grant.client, scope=scope, user=request.user) return render(request, 'authorize.html', context) if is_user_confirmed(request): # granted by resource owner - return server.create_authorization_response(request, grant_user=request.user) + return server.create_authorization_response(request, grant=grant, grant_user=request.user) # denied by resource owner - return server.create_authorization_response(request, grant_user=None) + return server.create_authorization_response(request, grant=grant, grant_user=None) # use ``server.create_token_response`` to handle token endpoint diff --git a/docs/django/2/endpoints.rst b/docs/oauth2/authorization-server/django/endpoints.rst similarity index 100% rename from docs/django/2/endpoints.rst rename to docs/oauth2/authorization-server/django/endpoints.rst diff --git a/docs/django/2/grants.rst b/docs/oauth2/authorization-server/django/grants.rst similarity index 97% rename from docs/django/2/grants.rst rename to docs/oauth2/authorization-server/django/grants.rst index 8c432644a..e0c5312ea 100644 --- a/docs/django/2/grants.rst +++ b/docs/oauth2/authorization-server/django/grants.rst @@ -67,8 +67,8 @@ grant type. Here is how:: code=code, client_id=client.client_id, redirect_uri=request.redirect_uri, - response_type=request.response_type, - scope=request.scope, + response_type=request.payload.response_type, + scope=request.payload.scope, user=request.user, ) auth_code.save() @@ -76,8 +76,8 @@ grant type. Here is how:: def query_authorization_code(self, code, client): try: - item = OAuth2Code.objects.get(code=code, client_id=client.client_id) - except OAuth2Code.DoesNotExist: + item = AuthorizationCode.objects.get(code=code, client_id=client.client_id) + except AuthorizationCode.DoesNotExist: return None if not item.is_expired(): diff --git a/docs/django/2/index.rst b/docs/oauth2/authorization-server/django/index.rst similarity index 87% rename from docs/django/2/index.rst rename to docs/oauth2/authorization-server/django/index.rst index 43b8927e2..a33113984 100644 --- a/docs/django/2/index.rst +++ b/docs/oauth2/authorization-server/django/index.rst @@ -1,7 +1,7 @@ .. _django_oauth2_server: -Django OAuth 2.0 Server -======================= +Django Integration +================== .. meta:: :description: How to create an OAuth 2.0 provider in Django with Authlib. @@ -25,6 +25,7 @@ At the very beginning, we need to have some basic understanding of export AUTHLIB_INSECURE_TRANSPORT=true Looking for Django OAuth 2.0 client? Check out :ref:`django_client`. +Looking for Django OAuth 2.0 resource server? Check out :ref:`django_resource_server`. .. toctree:: :maxdepth: 2 @@ -32,6 +33,5 @@ Looking for Django OAuth 2.0 client? Check out :ref:`django_client`. authorization-server grants endpoints - resource-server openid-connect api diff --git a/docs/django/2/openid-connect.rst b/docs/oauth2/authorization-server/django/openid-connect.rst similarity index 80% rename from docs/django/2/openid-connect.rst rename to docs/oauth2/authorization-server/django/openid-connect.rst index fe98140a3..c7729ad21 100644 --- a/docs/django/2/openid-connect.rst +++ b/docs/oauth2/authorization-server/django/openid-connect.rst @@ -106,31 +106,35 @@ extended features. We can apply the :class:`OpenIDCode` extension to First, we need to implement the missing methods for ``OpenIDCode``:: + from joserfc.jwk import KeySet from authlib.oidc.core import grants, UserInfo class OpenIDCode(grants.OpenIDCode): + def resolve_client_private_key(self, client): + with open(jwks_file_path) as f: + data = json.load(f) + return KeySet.import_key_set(data) + + def get_client_claims(self, client): + return { + 'iss': 'https://example.com', + } + def exists_nonce(self, nonce, request): try: AuthorizationCode.objects.get( - client_id=request.client_id, nonce=nonce + client_id=request.payload.client_id, nonce=nonce ) return True except AuthorizationCode.DoesNotExist: return False - def get_jwt_config(self, grant): - return { - 'key': read_private_key_file(key_path), - 'alg': 'RS512', - 'iss': 'https://example.com', - 'exp': 3600 - } - def generate_user_info(self, user, scope): - user_info = UserInfo(sub=str(user.pk), name=user.name) - if 'email' in scope: - user_info['email'] = user.email - return user_info + return UserInfo( + sub=str(user.pk), + name=user.name, + email=user.email, + ).filter(scope) Second, since there is one more ``nonce`` value in ``AuthorizationCode`` data, we need to save this value into database. In this case, we have to update our @@ -139,13 +143,13 @@ we need to save this value into database. In this case, we have to update our class AuthorizationCodeGrant(_AuthorizationCodeGrant): def save_authorization_code(self, code, request): # openid request MAY have "nonce" parameter - nonce = request.data.get('nonce') + nonce = request.payload.data.get('nonce') client = request.client auth_code = AuthorizationCode( code=code, client_id=client.client_id, redirect_uri=request.redirect_uri, - scope=request.scope, + scope=request.payload.scope, user=request.user, nonce=nonce, ) @@ -186,31 +190,35 @@ The Implicit Flow is mainly used by Clients implemented in a browser using a scripting language. You need to implement the missing methods of :class:`OpenIDImplicitGrant` before register it:: + from joserfc.jwk import KeySet from authlib.oidc.core import grants class OpenIDImplicitGrant(grants.OpenIDImplicitGrant): + def resolve_client_private_key(self, client): + with open(jwks_file_path) as f: + data = json.load(f) + return KeySet.import_key_set(data) + + def get_client_claims(self, client): + return { + 'iss': 'https://example.com', + } + def exists_nonce(self, nonce, request): try: AuthorizationCode.objects.get( - client_id=request.client_id, nonce=nonce) + client_id=request.payload.client_id, nonce=nonce) ) return True except AuthorizationCode.DoesNotExist: return False - def get_jwt_config(self): - return { - 'key': read_private_key_file(key_path), - 'alg': 'RS512', - 'iss': 'https://example.com', - 'exp': 3600 - } - def generate_user_info(self, user, scope): - user_info = UserInfo(sub=user.id, name=user.name) - if 'email' in scope: - user_info['email'] = user.email - return user_info + return UserInfo( + sub=str(user.pk), + name=user.name, + email=user.email, + ).filter(scope) server.register_grant(OpenIDImplicitGrant) @@ -226,18 +234,29 @@ OpenIDHybridGrant is a subclass of OpenIDImplicitGrant, so the missing methods are the same, except that OpenIDHybridGrant has one more missing method, that is ``save_authorization_code``. You can implement it like this:: + from joserfc.jwk import KeySet from authlib.oidc.core import grants class OpenIDHybridGrant(grants.OpenIDHybridGrant): + def resolve_client_private_key(self, client): + with open(jwks_file_path) as f: + data = json.load(f) + return KeySet.import_key_set(data) + + def get_client_claims(self, client): + return { + 'iss': 'https://example.com', + } + def save_authorization_code(self, code, request): # openid request MAY have "nonce" parameter - nonce = request.data.get('nonce') + nonce = request.payload.data.get('nonce') client = request.client auth_code = AuthorizationCode( code=code, client_id=client.client_id, redirect_uri=request.redirect_uri, - scope=request.scope, + scope=request.payload.scope, user=request.user, nonce=nonce, ) @@ -247,29 +266,23 @@ is ``save_authorization_code``. You can implement it like this:: def exists_nonce(self, nonce, request): try: AuthorizationCode.objects.get( - client_id=request.client_id, nonce=nonce) + client_id=request.payload.client_id, nonce=nonce) ) return True except AuthorizationCode.DoesNotExist: return False - def get_jwt_config(self): - return { - 'key': read_private_key_file(key_path), - 'alg': 'RS512', - 'iss': 'https://example.com', - 'exp': 3600 - } - def generate_user_info(self, user, scope): - user_info = UserInfo(sub=user.id, name=user.name) - if 'email' in scope: - user_info['email'] = user.email - return user_info + return UserInfo( + sub=str(user.pk), + name=user.name, + email=user.email, + ).filter(scope) # register it to grant endpoint server.register_grant(OpenIDHybridGrant) -Since all OpenID Connect Flow requires ``exists_nonce``, ``get_jwt_config`` -and ``generate_user_info`` methods, you can create shared functions for them. +Since all OpenID Connect Flow require ``exists_nonce``, ``resolve_client_private_key``, +``get_client_claims``, ``get_client_algorithm`` and ``generate_user_info`` methods, +you can create shared functions for them. diff --git a/docs/flask/2/api.rst b/docs/oauth2/authorization-server/flask/api.rst similarity index 88% rename from docs/flask/2/api.rst rename to docs/oauth2/authorization-server/flask/api.rst index 93089d181..b4c1db97f 100644 --- a/docs/flask/2/api.rst +++ b/docs/oauth2/authorization-server/flask/api.rst @@ -1,5 +1,5 @@ -API References of Flask OAuth 2.0 Server -======================================== +Reference +========= This part of the documentation covers the interface of Flask OAuth 2.0 Server. @@ -10,9 +10,8 @@ Server. :members: register_grant, register_endpoint, - create_token_expires_in_generator, create_bearer_token_generator, - validate_consent_request, + get_consent_grant, create_authorization_response, create_token_response, create_endpoint_response @@ -28,7 +27,7 @@ Server. from authlib.integrations.flask_oauth2 import current_token - @require_oauth() + @require_oauth @app.route('/user_id') def user_id(): # current token instance of the OAuth Token model diff --git a/docs/flask/2/authorization-server.rst b/docs/oauth2/authorization-server/flask/authorization-server.rst similarity index 88% rename from docs/flask/2/authorization-server.rst rename to docs/oauth2/authorization-server/flask/authorization-server.rst index 1ba8bd577..1807584b7 100644 --- a/docs/flask/2/authorization-server.rst +++ b/docs/oauth2/authorization-server/flask/authorization-server.rst @@ -55,7 +55,7 @@ Token .. note:: - Only Bearer Token is supported by now. MAC Token is still under drafts, + Only Bearer Token is supported for now. MAC Token is still under draft, it will be available when it goes into RFC. Tokens are used to access the users' resources. A token is issued with a @@ -82,7 +82,7 @@ A token is associated with a resource owner. There is no certain name for it, here we call it ``user``, but it can be anything else. If you decide to implement all the missing methods by yourself, get a deep -inside with :class:`~authlib.oauth2.rfc6749.TokenMixin` API reference. +inside the :class:`~authlib.oauth2.rfc6749.TokenMixin` API reference. Server ------ @@ -163,27 +163,40 @@ OAUTH2_ERROR_URIS A list of tuple for (``error``, ``error_uri`` Now define an endpoint for authorization. This endpoint is used by ``authorization_code`` and ``implicit`` grants:: + from authlib.oauth2 import OAuth2Error from flask import request, render_template from your_project.auth import current_user @app.route('/oauth/authorize', methods=['GET', 'POST']) def authorize(): + try: + grant = server.get_consent_grant(end_user=current_user) + except OAuth2Error as error: + return authorization.handle_error_response(request, error) + # Login is required since we need to know the current resource owner. # It can be done with a redirection to the login page, or a login # form on this authorization page. if request.method == 'GET': - grant = server.validate_consent_request(end_user=current_user) + scope = grant.client.get_allowed_scope(grant.request.payload.scope) + + # You may add a function to extract scope into a list of scopes + # with rich information, e.g. + scopes = describe_scope(scope) # returns [{'key': 'email', 'icon': '...'}] return render_template( 'authorize.html', grant=grant, user=current_user, + scopes=scopes, ) + confirmed = request.form['confirm'] if confirmed: # granted by resource owner - return server.create_authorization_response(grant_user=current_user) + return server.create_authorization_response(grant=grant, grant_user=current_user) + # denied by resource owner - return server.create_authorization_response(grant_user=None) + return server.create_authorization_response(grant=grant, grant_user=None) This is a simple demo, the real case should be more complex. There is a little more complex demo in https://github.com/authlib/example-oauth2-server. @@ -202,7 +215,7 @@ Register Error URIs ------------------- To create a better developer experience for debugging, it is suggested that -you creating some documentation for errors. Here is a list of built-in +you create some documentation for errors. Here is a list of built-in :ref:`specs/rfc6949-errors`. You can design a documentation page with a description of each error. For @@ -225,4 +238,4 @@ I18N on Errors ~~~~~~~~~~~~~~ It is also possible to add i18n support to the ``error_description``. The -feature has been implemented in version 0.8, but there are still work to do. +feature has been implemented in version 0.8, but there is still work to do. diff --git a/docs/flask/2/endpoints.rst b/docs/oauth2/authorization-server/flask/endpoints.rst similarity index 100% rename from docs/flask/2/endpoints.rst rename to docs/oauth2/authorization-server/flask/endpoints.rst diff --git a/docs/flask/2/grants.rst b/docs/oauth2/authorization-server/flask/grants.rst similarity index 86% rename from docs/flask/2/grants.rst rename to docs/oauth2/authorization-server/flask/grants.rst index 61f408b16..9fe03bc07 100644 --- a/docs/flask/2/grants.rst +++ b/docs/oauth2/authorization-server/flask/grants.rst @@ -14,8 +14,8 @@ Authorization Code Grant Authorization Code Grant is a very common grant type, it is supported by almost every OAuth 2 providers. It uses an authorization code to exchange access -token. In this case, we need a place to store the authorization code. It can be -kept in a database or a cache like redis. Here is a SQLAlchemy mixin for +tokens. In this case, we need a place to store the authorization code. It can +be kept in a database or a cache like redis. Here is a SQLAlchemy mixin for **AuthorizationCode**:: from authlib.integrations.sqla_oauth2 import OAuth2AuthorizationCodeMixin @@ -27,7 +27,7 @@ kept in a database or a cache like redis. Here is a SQLAlchemy mixin for ) user = db.relationship('User') -Implement this grant by subclass :class:`AuthorizationCodeGrant`:: +Implement this grant by subclassing :class:`AuthorizationCodeGrant`:: from authlib.oauth2.rfc6749 import grants @@ -38,7 +38,7 @@ Implement this grant by subclass :class:`AuthorizationCodeGrant`:: code=code, client_id=client.client_id, redirect_uri=request.redirect_uri, - scope=request.scope, + scope=request.payload.scope, user_id=request.user.id, ) db.session.add(auth_code) @@ -80,7 +80,7 @@ Implicit Grant -------------- The implicit grant type is usually used in a browser, when resource -owner granted the access, access token is issued in the redirect URI, +owner granted the access, an access token is issued in the redirect URI, there is no missing implementation, which means it can be easily registered with:: @@ -89,15 +89,16 @@ with:: # register it to grant endpoint server.register_grant(grants.ImplicitGrant) -Implicit Grant is used by **public** client which has no **client_secret**. -Only allowed :ref:`client_auth_methods`: ``none``. +Implicit Grant is used by **public** clients which have no **client_secret**. +Default allowed :ref:`client_auth_methods`: ``none``. Resource Owner Password Credentials Grant ----------------------------------------- -Resource owner uses their username and password to exchange an access token, -this grant type should be used only when the client is trustworthy, implement -it with a subclass of :class:`ResourceOwnerPasswordCredentialsGrant`:: +The resource owner uses its username and password to exchange an access +token. This grant type should be used only when the client is trustworthy; +implement it with a subclass of +:class:`ResourceOwnerPasswordCredentialsGrant`:: from authlib.oauth2.rfc6749 import grants @@ -142,8 +143,8 @@ You can add more in the subclass:: Refresh Token Grant ------------------- -Many OAuth 2 providers haven't implemented refresh token endpoint. Authlib -provides it as a grant type, implement it with a subclass of +Many OAuth 2 providers do not implement a refresh token endpoint. Authlib +provides it as a grant type; implement it with a subclass of :class:`RefreshTokenGrant`:: from authlib.oauth2.rfc6749 import grants @@ -192,10 +193,6 @@ supports two endpoints: 1. Authorization Endpoint: which can handle requests with ``response_type``. 2. Token Endpoint: which is the endpoint to issue tokens. -.. versionchanged:: v0.12 - Using ``AuthorizationEndpointMixin`` and ``TokenEndpointMixin`` instead of - ``AUTHORIZATION_ENDPOINT=True`` and ``TOKEN_ENDPOINT=True``. - Creating a custom grant type with **BaseGrant**:: from authlib.oauth2.rfc6749.grants import ( @@ -231,12 +228,12 @@ Grant Extensions .. versionadded:: 0.10 -Grant can accept extensions. Developers can pass extensions when registering -grant:: +Grants can accept extensions. Developers can pass extensions when registering +grants:: authorization_server.register_grant(AuthorizationCodeGrant, [extension]) -For instance, there is ``CodeChallenge`` extension in Authlib:: +For instance, there is the ``CodeChallenge`` extension in Authlib:: server.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=False)]) diff --git a/docs/flask/2/index.rst b/docs/oauth2/authorization-server/flask/index.rst similarity index 90% rename from docs/flask/2/index.rst rename to docs/oauth2/authorization-server/flask/index.rst index eb7cf8d0c..cc2a695b5 100644 --- a/docs/flask/2/index.rst +++ b/docs/oauth2/authorization-server/flask/index.rst @@ -1,7 +1,7 @@ .. _flask_oauth2_server: -Flask OAuth 2.0 Server -====================== +Flask Integration +================= .. meta:: :description: How to create an OAuth 2.0 provider in Flask with Authlib. @@ -30,6 +30,7 @@ At the very beginning, we need to have some basic understanding of export AUTHLIB_INSECURE_TRANSPORT=true Looking for Flask OAuth 2.0 client? Check out :ref:`flask_client`. +Looking for Flask OAuth 2.0 resource server? Check out :ref:`flask_oauth2_resource_protector`. .. toctree:: :maxdepth: 2 @@ -37,6 +38,5 @@ Looking for Flask OAuth 2.0 client? Check out :ref:`flask_client`. authorization-server grants endpoints - resource-server openid-connect api diff --git a/docs/flask/2/openid-connect.rst b/docs/oauth2/authorization-server/flask/openid-connect.rst similarity index 69% rename from docs/flask/2/openid-connect.rst rename to docs/oauth2/authorization-server/flask/openid-connect.rst index e24243e24..4e7c62142 100644 --- a/docs/flask/2/openid-connect.rst +++ b/docs/oauth2/authorization-server/flask/openid-connect.rst @@ -15,11 +15,6 @@ Since OpenID Connect is built on OAuth 2.0 frameworks, you need to read .. module:: authlib.oauth2.rfc6749.grants :noindex: -.. versionchanged:: v0.12 - - The Grant system has been redesigned from v0.12. This documentation ONLY - works for Authlib >=v0.12. - Looking for OpenID Connect Client? Head over to :ref:`flask_client`. Understand JWT @@ -28,7 +23,7 @@ Understand JWT OpenID Connect 1.0 uses JWT a lot. Make sure you have the basic understanding of :ref:`jose`. -For OpenID Connect, we need to understand at lease four concepts: +For OpenID Connect, we need to understand at least four concepts: 1. **alg**: Algorithm for JWT 2. **key**: Private key for JWT @@ -67,7 +62,7 @@ secrets between server and client. Most OpenID Connect services are using key ~~~ -A private key is required to generate JWT. The key that you are going to use +A private key is required to generate a JWT. The key that you are going to use dependents on the ``alg`` you are using. For instance, the alg is ``RS256``, you need to use an RSA private key. It can be set with:: @@ -79,8 +74,8 @@ you need to use an RSA private key. It can be set with:: iss ~~~ -The ``iss`` value in JWT payload. The value can be your website name or URL. -For example, Google is using:: +The ``iss`` value in the JWT payload. The value can be your website name or +URL. For example, Google is using:: {"iss": "https://accounts.google.com"} @@ -98,42 +93,46 @@ extended features. We can apply the :class:`OpenIDCode` extension to First, we need to implement the missing methods for ``OpenIDCode``:: + from joserfc.jwk import KeySet from authlib.oidc.core import grants, UserInfo class OpenIDCode(grants.OpenIDCode): - def exists_nonce(self, nonce, request): - exists = AuthorizationCode.query.filter_by( - client_id=request.client_id, nonce=nonce - ).first() - return bool(exists) + def resolve_client_private_key(self, client): + with open(jwks_file_path) as f: + data = json.load(f) + return KeySet.import_key_set(data) - def get_jwt_config(self, grant): + def get_client_claims(self, client): return { - 'key': read_private_key_file(key_path), - 'alg': 'RS512', 'iss': 'https://example.com', - 'exp': 3600 } + def exists_nonce(self, nonce, request): + exists = AuthorizationCode.query.filter_by( + client_id=request.payload.client_id, nonce=nonce + ).first() + return bool(exists) + def generate_user_info(self, user, scope): - user_info = UserInfo(sub=user.id, name=user.name) - if 'email' in scope: - user_info['email'] = user.email - return user_info + return UserInfo( + sub=user.id, + name=user.name, + email=user.email, + ).filter(scope) -Second, since there is one more ``nonce`` value in ``AuthorizationCode`` data, -we need to save this value into database. In this case, we have to update our -:ref:`flask_oauth2_code_grant` ``save_authorization_code`` method:: +Second, since there is one more ``nonce`` value in the ``AuthorizationCode`` +data, we need to save this value into the database. In this case, we have to +update our :ref:`flask_oauth2_code_grant` ``save_authorization_code`` method:: class AuthorizationCodeGrant(_AuthorizationCodeGrant): def save_authorization_code(self, code, request): # openid request MAY have "nonce" parameter - nonce = request.data.get('nonce') + nonce = request.payload.data.get('nonce') auth_code = AuthorizationCode( code=code, client_id=request.client.client_id, redirect_uri=request.redirect_uri, - scope=request.scope, + scope=request.payload.scope, user_id=request.user.id, nonce=nonce, ) @@ -143,14 +142,14 @@ we need to save this value into database. In this case, we have to update our # ... -Finally, you can register ``AuthorizationCodeGrant`` with ``OpenIDCode`` +Finally, you can register ``AuthorizationCodeGrant`` with the ``OpenIDCode`` extension:: # register it to grant endpoint server.register_grant(AuthorizationCodeGrant, [OpenIDCode(require_nonce=True)]) The difference between OpenID Code flow and the standard code flow is that -OpenID Connect request has a scope of "openid": +OpenID Connect requests have a scope of "openid": .. code-block:: http @@ -176,30 +175,34 @@ Implicit Flow The Implicit Flow is mainly used by Clients implemented in a browser using a scripting language. You need to implement the missing methods of -:class:`OpenIDImplicitGrant` before register it:: +:class:`OpenIDImplicitGrant` before registering it:: + from joserfc.jwk import KeySet from authlib.oidc.core import grants class OpenIDImplicitGrant(grants.OpenIDImplicitGrant): - def exists_nonce(self, nonce, request): - exists = AuthorizationCode.query.filter_by( - client_id=request.client_id, nonce=nonce - ).first() - return bool(exists) + def resolve_client_private_key(self, client): + with open(jwks_file_path) as f: + data = json.load(f) + return KeySet.import_key_set(data) - def get_jwt_config(self): + def get_client_claims(self, client): return { - 'key': read_private_key_file(key_path), - 'alg': 'RS512', 'iss': 'https://example.com', - 'exp': 3600 } + def exists_nonce(self, nonce, request): + exists = AuthorizationCode.query.filter_by( + client_id=request.payload.client_id, nonce=nonce + ).first() + return bool(exists) + def generate_user_info(self, user, scope): - user_info = UserInfo(sub=user.id, name=user.name) - if 'email' in scope: - user_info['email'] = user.email - return user_info + return UserInfo( + sub=user.id, + name=user.name, + email=user.email, + ).filter(scope) server.register_grant(OpenIDImplicitGrant) @@ -208,25 +211,39 @@ a scripting language. You need to implement the missing methods of Hybrid Flow ------------ -Hybrid flow is a mix of the code flow and implicit flow. You only need to -implement the authorization endpoint part, token endpoint will be handled +The Hybrid flow is a mix of code flow and implicit flow. You only need to +implement the authorization endpoint part, as token endpoint will be handled by Authorization Code Flow. OpenIDHybridGrant is a subclass of OpenIDImplicitGrant, so the missing methods are the same, except that OpenIDHybridGrant has one more missing method, that is ``save_authorization_code``. You can implement it like this:: + from joserfc.jwk import KeySet from authlib.oidc.core import grants from authlib.common.security import generate_token class OpenIDHybridGrant(grants.OpenIDHybridGrant): + def resolve_client_private_key(self, client): + with open(jwks_file_path) as f: + data = json.load(f) + return KeySet.import_key_set(data) + + def get_client_claims(self, client): + return { + 'iss': 'https://example.com', + } + + def get_client_algorithm(self, client): + return 'RS512' + def save_authorization_code(self, code, request): - nonce = request.data.get('nonce') + nonce = request.payload.data.get('nonce') item = AuthorizationCode( code=code, client_id=request.client.client_id, redirect_uri=request.redirect_uri, - scope=request.scope, + scope=request.payload.scope, user_id=request.user.id, nonce=nonce, ) @@ -236,29 +253,23 @@ is ``save_authorization_code``. You can implement it like this:: def exists_nonce(self, nonce, request): exists = AuthorizationCode.query.filter_by( - client_id=request.client_id, nonce=nonce + client_id=request.payload.client_id, nonce=nonce ).first() return bool(exists) - def get_jwt_config(self): - return { - 'key': read_private_key_file(key_path), - 'alg': 'RS512', - 'iss': 'https://example.com', - 'exp': 3600 - } - def generate_user_info(self, user, scope): - user_info = UserInfo(sub=user.id, name=user.name) - if 'email' in scope: - user_info['email'] = user.email - return user_info + return UserInfo( + sub=user.id, + name=user.name, + email=user.email, + ).filter(scope) # register it to grant endpoint server.register_grant(OpenIDHybridGrant) -Since all OpenID Connect Flow requires ``exists_nonce``, ``get_jwt_config`` -and ``generate_user_info`` methods, you can create shared functions for them. +Since all OpenID Connect Flow require ``exists_nonce``, ``resolve_client_private_key``, +``get_client_claims``, ``get_client_algorithm`` and ``generate_user_info`` methods, +you can create shared functions for them. Find the `example of OpenID Connect server `_. diff --git a/docs/oauth2/authorization-server/index.rst b/docs/oauth2/authorization-server/index.rst new file mode 100644 index 000000000..5e70005c3 --- /dev/null +++ b/docs/oauth2/authorization-server/index.rst @@ -0,0 +1,44 @@ +.. _authorization_server: + +Authorization Server +==================== + +An Authorization Server is the component that authenticates users and issues +access tokens to clients. Build this when you want to run your own OAuth 2.0 +or OpenID Connect provider. + +Not sure this is the right role? See :ref:`intro_oauth2` for an overview of +all OAuth 2.0 roles. + +Looking for the :ref:`resource_server` (protecting an API)? +Or the :ref:`oauth_client` (consuming an OAuth provider)? + +Understand +---------- + +Before implementing, read the concept guides: + +* :ref:`intro_oauth2` — OAuth 2.0 roles, flows, and grant types + +How-to +------ + +OAuth 2.0 +~~~~~~~~~ + +.. toctree:: + :maxdepth: 2 + + flask/index + django/index + +Reference +--------- + +Relevant specifications: + +* :doc:`../specs/rfc6749` — The OAuth 2.0 Authorization Framework +* :doc:`../specs/rfc7636` — PKCE +* :doc:`../specs/rfc7591` — Dynamic Client Registration +* :doc:`../specs/rfc8414` — Authorization Server Metadata +* :doc:`../specs/oidc` — OpenID Connect Core diff --git a/docs/oauth2/client/http/api.rst b/docs/oauth2/client/http/api.rst new file mode 100644 index 000000000..319868cdd --- /dev/null +++ b/docs/oauth2/client/http/api.rst @@ -0,0 +1,63 @@ +Reference +========= + +.. meta:: + :description: API references on Authlib OAuth 2.0 HTTP session clients. + +Requests OAuth 2.0 +------------------ + +.. module:: authlib.integrations.requests_client + :no-index: + +.. autoclass:: OAuth2Session + :no-index: + :members: + register_client_auth_method, + create_authorization_url, + fetch_token, + refresh_token, + revoke_token, + introspect_token, + register_compliance_hook + +.. autoclass:: OAuth2Auth + :no-index: + +.. autoclass:: AssertionSession + :no-index: + + +HTTPX OAuth 2.0 +--------------- + +.. module:: authlib.integrations.httpx_client + :no-index: + +.. autoclass:: OAuth2Auth + :no-index: + +.. autoclass:: OAuth2Client + :no-index: + :members: + register_client_auth_method, + create_authorization_url, + fetch_token, + refresh_token, + revoke_token, + introspect_token, + register_compliance_hook + +.. autoclass:: AsyncOAuth2Client + :no-index: + :members: + register_client_auth_method, + create_authorization_url, + fetch_token, + refresh_token, + revoke_token, + introspect_token, + register_compliance_hook + +.. autoclass:: AsyncAssertionClient + :no-index: diff --git a/docs/client/httpx.rst b/docs/oauth2/client/http/httpx.rst similarity index 79% rename from docs/client/httpx.rst rename to docs/oauth2/client/http/httpx.rst index 48412f1f8..a5d87980e 100644 --- a/docs/client/httpx.rst +++ b/docs/oauth2/client/http/httpx.rst @@ -1,48 +1,31 @@ .. _httpx_client: -OAuth for HTTPX -=============== +OAuth 2.0 for HTTPX +=================== .. meta:: - :description: An OAuth 1.0 and OAuth 2.0 Client implementation for a next + :description: An OAuth 2.0 Client implementation for a next generation HTTP client for Python, including support for OpenID Connect and service account, powered by Authlib. .. module:: authlib.integrations.httpx_client :noindex: -HTTPX is a next-generation HTTP client for Python. Authlib enables OAuth 1.0 -and OAuth 2.0 for HTTPX with its async versions: +HTTPX is a next-generation HTTP client for Python. Authlib enables OAuth 2.0 +for HTTPX with its async versions: -* :class:`OAuth1Client` * :class:`OAuth2Client` * :class:`AssertionClient` -* :class:`AsyncOAuth1Client` * :class:`AsyncOAuth2Client` * :class:`AsyncAssertionClient` .. note:: HTTPX is still in its "alpha" stage, use it with caution. -HTTPX OAuth 1.0 ---------------- - -There are three steps in OAuth 1 to obtain an access token: - -1. fetch a temporary credential -2. visit the authorization page -3. exchange access token with the temporary credential - -It shares a common API design with :ref:`requests_client`. - -Read the common guide of :ref:`oauth_1_session` to understand the whole OAuth -1.0 flow. - - HTTPX OAuth 2.0 --------------- -In :ref:`oauth_2_session`, there are many grant types, including: +In :ref:`OAuth 2 Session `, there are many grant types, including: 1. Authorization Code Flow 2. Implicit Flow @@ -51,7 +34,7 @@ In :ref:`oauth_2_session`, there are many grant types, including: And also, Authlib supports non Standard OAuth 2.0 providers via Compliance Fix. -Read the common guide of :ref:`oauth_2_session` to understand the whole OAuth +Read the common guide of :ref:`OAuth 2 Session ` to understand the whole OAuth 2.0 flow. Using ``client_secret_jwt`` in HTTPX @@ -97,30 +80,11 @@ authentication method for HTTPX:: The ``PrivateKeyJWT`` is provided by :ref:`specs/rfc7523`. -Async OAuth 1.0 ---------------- - -The async version of :class:`AsyncOAuth1Client` works the same as -:ref:`oauth_1_session`, except that we need to add ``await`` when -required:: - - # fetching request token - request_token = await client.fetch_request_token(request_token_url) - - # fetching access token - access_token = await client.fetch_access_token(access_token_url) - - # normal requests - await client.get(...) - await client.post(...) - await client.put(...) - await client.delete(...) - Async OAuth 2.0 --------------- The async version of :class:`AsyncOAuth2Client` works the same as -:ref:`oauth_2_session`, except that we need to add ``await`` when +:ref:`OAuth 2 Session `, except that we need to add ``await`` when required:: # fetching access token diff --git a/docs/client/oauth2.rst b/docs/oauth2/client/http/index.rst similarity index 82% rename from docs/client/oauth2.rst rename to docs/oauth2/client/http/index.rst index 63a4a1fab..7ec5cbe1c 100644 --- a/docs/client/oauth2.rst +++ b/docs/oauth2/client/http/index.rst @@ -1,7 +1,7 @@ .. _oauth_2_session: -OAuth 2 Session -=============== +HTTP Clients +============ .. meta:: :description: An OAuth 2.0 Client implementation for Python requests, @@ -10,10 +10,6 @@ OAuth 2 Session .. module:: authlib.integrations :noindex: -.. versionchanged:: v0.13 - - All client related code have been moved into ``authlib.integrations``. For - earlier versions of Authlib, check out their own versions documentation. This documentation covers the common design of a Python OAuth 2.0 client. Authlib provides three implementations of OAuth 2.0 client: @@ -26,12 +22,15 @@ Authlib provides three implementations of OAuth 2.0 client: :class:`requests_client.OAuth2Session` and :class:`httpx_client.AsyncOAuth2Client` shares the same API. -There are also frameworks integrations of :ref:`flask_client`, :ref:`django_client` -and :ref:`starlette_client`. If you are using these frameworks, you may have interests -in their own documentation. - If you are not familiar with OAuth 2.0, it is better to read :ref:`intro_oauth2` now. +.. toctree:: + :maxdepth: 1 + + requests + httpx + api + OAuth2Session for Authorization Code ------------------------------------ @@ -95,7 +94,7 @@ the state in case of CSRF attack:: Save this token to access users' protected resources. -In real project, this session can not be re-used since you are redirected to +In real project, this session can not be reused since you are redirected to another website. You need to create another session yourself:: >>> state = restore_previous_state() @@ -106,12 +105,25 @@ another website. You need to create another session yourself:: >>> >>> # using httpx >>> from authlib.integrations.httpx_client import AsyncOAuth2Client - >>> client = OAuth2Client(client_id, client_secret, state=state) + >>> client = AsyncOAuth2Client(client_id, client_secret, state=state) >>> - >>> client.fetch_token(token_endpoint, authorization_response=authorization_response) + >>> await client.fetch_token(token_endpoint, authorization_response=authorization_response) Authlib has a built-in Flask/Django integration. Learn from them. +Add PKCE for Authorization Code +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Authlib client can handle PKCE automatically, just pass ``code_verifier`` to ``create_authorization_url`` +and ``fetch_token``:: + + >>> client = OAuth2Session(..., code_challenge_method='S256') + >>> code_verifier = generate_token(48) + >>> uri, state = client.create_authorization_url(authorization_endpoint, code_verifier=code_verifier) + >>> # ... + >>> token = client.fetch_token(..., code_verifier=code_verifier) + + OAuth2Session for Implicit -------------------------- @@ -122,7 +134,7 @@ the ``response_type`` of ``token``:: >>> print(uri) https://some-service.com/oauth/authorize?response_type=token&client_id=be..4d&... -Visit this link, and grant the authorization, the OAuth authoirzation server will +Visit this link, and grant the authorization, the OAuth authorization server will redirect back to your redirect_uri, the response url would be something like:: https://example.com/cb#access_token=2..WpA&state=xyz&token_type=bearer&expires_in=3600 @@ -190,7 +202,7 @@ These two methods are defined by RFC7523 and OpenID Connect. Find more in :ref:`jwt_oauth2session`. There are still cases that developers need to define a custom client -authentication method. Take :gh:`issue#158` as an example, the provider +authentication method. Take :issue:`158` as an example, the provider requires us put ``client_id`` and ``client_secret`` on URL when sending POST request:: @@ -218,12 +230,18 @@ is how we can get our OAuth 2.0 client authenticated:: client.register_client_auth_method(('client_secret_uri', auth_client_secret_uri)) With ``client_secret_uri`` registered, OAuth 2.0 client will authenticate with -the signed URI. +the signed URI. It is also possible to assign the function to ``token_endpoint_auth_method`` +directly:: + + client = OAuth2Session( + 'client_id', 'client_secret', + token_endpoint_auth_method=auth_client_secret_uri, + ) Access Protected Resources -------------------------- -Now you can access the protected resources. If you re-use the session, you +Now you can access the protected resources. If you reuse the session, you don't need to do anything:: >>> account_url = 'https://api.github.com/user' @@ -251,6 +269,23 @@ protected resources. In this case, we can refresh the token manually, or even better, Authlib will refresh the token automatically and update the token for us. +Automatically refreshing tokens +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If your :class:`~requests_client.OAuth2Session` class was created with the +`token_endpoint` parameter, Authlib will automatically refresh the token when +it has expired:: + + >>> openid_configuration = requests.get("https://example.org/.well-known/openid-configuration").json() + >>> session = OAuth2Session(…, token_endpoint=openid_configuration["token_endpoint"]) + +By default, the token will be refreshed 60 seconds before its actual expiry time, to avoid clock skew issues. +You can control this behaviour by setting the ``leeway`` parameter of the :class:`~requests_client.OAuth2Session` +class. + +Manually refreshing tokens +~~~~~~~~~~~~~~~~~~~~~~~~~~ + To call :meth:`~requests_client.OAuth2Session.refresh_token` manually means we are going to exchange a new "access_token" with "refresh_token":: @@ -287,6 +322,25 @@ call this our defined ``update_token`` to save the new token:: # if the token is expired, this GET request will update token client.get('https://openidconnect.googleapis.com/v1/userinfo') +Revoke and Introspect Token +--------------------------- + +If the provider support token revocation and introspection, you can revoke +and introspect the token with:: + + token_endpoint = 'https://example.com/oauth/token' + + token = get_your_previous_saved_token() + client.revoke_token(token_endpoint, token=token) + client.introspect_token(token_endpoint, token=token) + +You can find the available parameters in API docs: + +- :meth:`requests_client.OAuth2Session.revoke_token` +- :meth:`requests_client.OAuth2Session.introspect_token` +- :meth:`httpx_client.AsyncOAuth2Client.revoke_token` +- :meth:`httpx_client.AsyncOAuth2Client.introspect_token` + .. _compliance_fix_oauth2: Compliance Fix for non Standard @@ -361,15 +415,17 @@ This ``id_token`` is a JWT text, it can not be used unless it is parsed. Authlib has provided tools for parsing and validating OpenID Connect id_token:: >>> from authlib.oidc.core import CodeIDToken - >>> from authlib.jose import jwt + >>> from joserfc import jwt >>> # GET keys from https://www.googleapis.com/oauth2/v3/certs - >>> claims = jwt.decode(resp['id_token'], keys, claims_cls=CodeIDToken) + >>> token = jwt.decode(resp['id_token'], keys) + >>> claims = CodeIDToken(token.claims, token.header) >>> claims.validate() -Get deep inside with :class:`~authlib.jose.JsonWebToken` and -:class:`~authlib.oidc.core.CodeIDToken`. Learn how to validate JWT claims -at :ref:`jwt_guide`. +.. versionchanged:: 1.7 + We use joserfc_ for JWT encoding and decoding. Checkout the JWT guide + on https://jose.authlib.org/en/guide/jwt/ +.. _joserfc: https://jose.authlib.org/en/ .. _assertion_session: diff --git a/docs/client/requests.rst b/docs/oauth2/client/http/requests.rst similarity index 64% rename from docs/client/requests.rst rename to docs/oauth2/client/http/requests.rst index 815cfb8cb..bf57d74f4 100644 --- a/docs/client/requests.rst +++ b/docs/oauth2/client/http/requests.rst @@ -1,58 +1,24 @@ .. _requests_client: -OAuth for Requests -================== +OAuth 2.0 for Requests +====================== .. meta:: - :description: An OAuth 1.0 and OAuth 2.0 Client implementation for Python requests, + :description: An OAuth 2.0 Client implementation for Python requests, including support for OpenID Connect and service account, powered by Authlib. .. module:: authlib.integrations.requests_client :noindex: -Requests is a very popular HTTP library for Python. Authlib enables OAuth 1.0 -and OAuth 2.0 for Requests with its :class:`OAuth1Session`, :class:`OAuth2Session` -and :class:`AssertionSession`. - - -Requests OAuth 1.0 ------------------- - -There are three steps in :ref:`oauth_1_session` to obtain an access token: - -1. fetch a temporary credential -2. visit the authorization page -3. exchange access token with the temporary credential - -It shares a common API design with :ref:`httpx_client`. - -OAuth1Session -~~~~~~~~~~~~~ - -The requests integration follows our common guide of :ref:`oauth_1_session`. -Follow the documentation in :ref:`oauth_1_session` instead. - -OAuth1Auth -~~~~~~~~~~ - -It is also possible to use :class:`OAuth1Auth` directly with in requests. -After we obtained access token from an OAuth 1.0 provider, we can construct -an ``auth`` instance for requests:: - - auth = OAuth1Auth( - client_id='YOUR-CLIENT-ID', - client_secret='YOUR-CLIENT-SECRET', - token='oauth_token', - token_secret='oauth_token_secret', - ) - requests.get(url, auth=auth) +Requests is a very popular HTTP library for Python. Authlib enables OAuth 2.0 +for Requests with its :class:`OAuth2Session` and :class:`AssertionSession`. Requests OAuth 2.0 ------------------ -In :ref:`oauth_2_session`, there are many grant types, including: +In :ref:`OAuth 2 Session `, there are many grant types, including: 1. Authorization Code Flow 2. Implicit Flow @@ -61,7 +27,7 @@ In :ref:`oauth_2_session`, there are many grant types, including: And also, Authlib supports non Standard OAuth 2.0 providers via Compliance Fix. -Follow the common guide of :ref:`oauth_2_session` to find out how to use +Follow the common guide of :ref:`OAuth 2 Session ` to find out how to use requests integration of OAuth 2.0 flow. @@ -159,24 +125,13 @@ Self-Signed Certificate Self-signed certificate mutual-TLS method internet standard is defined in `RFC8705 Section 2.2`_ . -For specifics development purposes only, you may need to -**disable SSL verification**. - -You can force all requests to disable SSL verification by setting -your environment variable ``CURL_CA_BUNDLE=""``. +You can use the environment variables CURL_CA_BUNDLE and REQUESTS_CA_BUNDLE +to specify a CA certificate file for validating your self-signed certificate. -This solutions works because Python requests (and most of the packages) -overwrites the default value for ssl verifications from environment -variables ``CURL_CA_BUNDLE`` and ``REQUESTS_CA_BUNDLE``. +.. code-block:: bash -This hack will **only work** with ``CURL_CA_BUNDLE``, as you can see -in `requests/sessions.py`_ :: - - verify = (os.environ.get('REQUESTS_CA_BUNDLE') - or os.environ.get('CURL_CA_BUNDLE')) + REQUESTS_CA_BUNDLE=/path/to/ca-cert.pem Please remember to set the env variable only in you development environment. - .. _RFC8705 Section 2.2: https://tools.ietf.org/html/rfc8705#section-2.2 -.. _requests/sessions.py: https://github.com/requests/requests/blob/master/requests/sessions.py#L706 diff --git a/docs/oauth2/client/index.rst b/docs/oauth2/client/index.rst new file mode 100644 index 000000000..642138fc3 --- /dev/null +++ b/docs/oauth2/client/index.rst @@ -0,0 +1,39 @@ +.. _oauth_client: + +Client +====== + +.. meta:: + :description: Python OAuth 2.0 Client implementations with requests, HTTPX, + Flask, Django and Starlette, powered by Authlib. + +Authlib provides OAuth 2.0 client implementations for two distinct use cases: + +**HTTP Clients** — your Python code fetches tokens and calls APIs directly. +Suitable for scripts, CLIs, service-to-service communication:: + + from authlib.integrations.requests_client import OAuth2Session + + client = OAuth2Session(client_id, client_secret) + token = client.fetch_token(token_endpoint, ...) + resp = client.get('https://api.example.com/data') + +**Web Clients** — your web application delegates authentication to an OAuth 2.0 +provider. Works with any provider: well-known services (GitHub, Google…) or +your own authorization server. Integrations for Flask, Django, Starlette and +FastAPI:: + + from authlib.integrations.flask_client import OAuth + + oauth = OAuth(app) + github = oauth.register('github', {...}) + + @app.route('/login') + def login(): + return github.authorize_redirect(url_for('authorize', _external=True)) + +.. toctree:: + :maxdepth: 2 + + http/index + web/index diff --git a/docs/oauth2/client/web/api.rst b/docs/oauth2/client/web/api.rst new file mode 100644 index 000000000..968c41e4e --- /dev/null +++ b/docs/oauth2/client/web/api.rst @@ -0,0 +1,44 @@ +Reference +========= + +.. meta:: + :description: API references on Authlib OAuth 2.0 web framework client integrations. + +Flask Registry and RemoteApp +----------------------------- + +.. module:: authlib.integrations.flask_client + :no-index: + +.. autoclass:: OAuth + :no-index: + :members: + init_app, + register, + create_client + + +Django Registry and RemoteApp +------------------------------ + +.. module:: authlib.integrations.django_client + :no-index: + +.. autoclass:: OAuth + :no-index: + :members: + register, + create_client + + +Starlette Registry and RemoteApp +--------------------------------- + +.. module:: authlib.integrations.starlette_client + :no-index: + +.. autoclass:: OAuth + :no-index: + :members: + register, + create_client diff --git a/docs/oauth2/client/web/django.rst b/docs/oauth2/client/web/django.rst new file mode 100644 index 000000000..d10c861b0 --- /dev/null +++ b/docs/oauth2/client/web/django.rst @@ -0,0 +1,153 @@ +.. _django_client: + +Django Integration +================== + +.. meta:: + :description: The built-in Django integrations for OAuth 2.0 + clients, powered by Authlib. + +.. module:: authlib.integrations.django_client + :noindex: + +Looking for OAuth 2.0 server? + +- :ref:`django_oauth2_server` + +The Django client handles OAuth 2.0 services. Authlib has a shared API design +among framework integrations. Get started with :ref:`frameworks_clients`. + +Create a registry with :class:`OAuth` object:: + + from authlib.integrations.django_client import OAuth + + oauth = OAuth() + +The common use case for OAuth is authentication, e.g. let your users log in +with Twitter, GitHub, Google etc. + +.. important:: + + Please read :ref:`frameworks_clients` at first. Authlib has a shared API + design among framework integrations, learn them from :ref:`frameworks_clients`. + + +Configuration +------------- + +Authlib Django OAuth registry can load the configuration from your Django +application settings automatically. Every key value pair can be omitted. +They can be configured from your Django settings:: + + AUTHLIB_OAUTH_CLIENTS = { + 'github': { + 'client_id': 'GitHub Client ID', + 'client_secret': 'GitHub Client Secret', + 'access_token_url': 'https://github.com/login/oauth/access_token', + 'authorize_url': 'https://github.com/login/oauth/authorize', + 'api_base_url': 'https://api.github.com/', + 'client_kwargs': {'scope': 'user:email'}, + } + } + +Please check the parameters in ``.register`` in :ref:`frameworks_clients`. + +Routes for Authorization +------------------------ + +Just like the example in :ref:`frameworks_clients`, everything is the same. +But there is a hint to create ``redirect_uri`` with ``request`` in Django:: + + def login(request): + # build a full authorize callback uri + redirect_uri = request.build_absolute_uri('/authorize') + return oauth.twitter.authorize_redirect(request, redirect_uri) + + +Auto Update Token via Signal +---------------------------- + +Instead of defining an ``update_token`` method and passing it into OAuth registry, +it is also possible to use signals to listen for token updates:: + + from django.dispatch import receiver + from authlib.integrations.django_client import token_update + + @receiver(token_update) + def on_token_update(sender, name, token, refresh_token=None, access_token=None, **kwargs): + if refresh_token: + item = OAuth2Token.find(name=name, refresh_token=refresh_token) + elif access_token: + item = OAuth2Token.find(name=name, access_token=access_token) + else: + return + + # update old token + item.access_token = token['access_token'] + item.refresh_token = token.get('refresh_token') + item.expires_at = token['expires_at'] + item.save() + + +Django OpenID Connect Client +---------------------------- + +An OpenID Connect client is no different than a normal OAuth 2.0 client. When +registered with the ``openid`` scope, the built-in Django OAuth client will handle +everything automatically:: + + oauth.register( + 'google', + ... + server_metadata_url='https://accounts.google.com/.well-known/openid-configuration', + client_kwargs={'scope': 'openid profile email'} + ) + +When we get the returned token:: + + token = oauth.google.authorize_access_token(request) + +There should be a ``id_token`` in the response. Authlib has called `.parse_id_token` +automatically, we can get ``userinfo`` in the ``token``:: + + userinfo = token['userinfo'] + +RP-Initiated Logout +------------------- + +To implement `OpenID Connect RP-Initiated Logout`_, use the ``logout_redirect`` method +to redirect users to the provider's end session endpoint:: + + def logout(request): + # Retrieve the ID token you stored during login + id_token = request.session.pop('id_token', None) + redirect_uri = request.build_absolute_uri('/logged-out') + return oauth.google.logout_redirect( + request, + post_logout_redirect_uri=redirect_uri, + id_token_hint=id_token, + ) + + def logged_out(request): + state_data = oauth.google.validate_logout_response(request) + return HttpResponse('You have been logged out.') + +.. _OpenID Connect RP-Initiated Logout: https://openid.net/specs/openid-connect-rpinitiated-1_0.html + +The ``logout_redirect`` method accepts: + +- ``request``: The Django request object (required) +- ``post_logout_redirect_uri``: Where to redirect after logout (must be registered with the provider) +- ``id_token_hint``: The ID token previously issued (recommended) +- ``state``: Opaque value for CSRF protection (auto-generated if not provided) +- ``client_id``: OAuth 2.0 Client Identifier (optional) +- ``logout_hint``: Hint about the user logging out (optional) +- ``ui_locales``: Preferred languages for the logout UI (optional) + +.. note:: + + You must store the ``id_token`` during login to use it later for logout. + The ``id_token`` is available in ``token['id_token']`` after calling + ``authorize_access_token()``. + +Find Django Google login example at https://github.com/authlib/demo-oauth-client/tree/master/django-google-login diff --git a/docs/client/fastapi.rst b/docs/oauth2/client/web/fastapi.rst similarity index 67% rename from docs/client/fastapi.rst rename to docs/oauth2/client/web/fastapi.rst index f719cc792..4aa3bf831 100644 --- a/docs/client/fastapi.rst +++ b/docs/oauth2/client/web/fastapi.rst @@ -1,11 +1,11 @@ .. _fastapi_client: -FastAPI OAuth Client -==================== +FastAPI Integration +=================== .. meta:: :description: Use Authlib built-in Starlette integrations to build - OAuth 1.0, OAuth 2.0 and OpenID Connect clients for FastAPI. + OAuth 2.0 and OpenID Connect clients for FastAPI. .. module:: authlib.integrations.starlette_client :noindex: @@ -29,35 +29,25 @@ Here is how you would create a FastAPI application:: Since Authlib starlette requires using ``request`` instance, we need to expose that ``request`` to Authlib. According to the documentation on -`Using the Request Directly `_:: +`Using the Request Directly `_:: from starlette.requests import Request - @app.get("/login") - def login_via_google(request: Request): - redirect_uri = 'https://example.com/auth' + @app.get("/login/google") + async def login_via_google(request: Request): + redirect_uri = request.url_for('auth_via_google') return await oauth.google.authorize_redirect(request, redirect_uri) - @app.get("/auth") - def auth_via_google(request: Request): + @app.get("/auth/google") + async def auth_via_google(request: Request): token = await oauth.google.authorize_access_token(request) - user = await oauth.google.parse_id_token(request, token) + user = token['userinfo'] return dict(user) .. _FastAPI: https://fastapi.tiangolo.com/ All other APIs are the same with Starlette. -FastAPI OAuth 1.0 Client ------------------------- - -We have a blog post about how to create Twitter login in FastAPI: - -https://blog.authlib.org/2020/fastapi-twitter-login - -FastAPI OAuth 2.0 Client ------------------------- - We have a blog post about how to create Google login in FastAPI: https://blog.authlib.org/2020/fastapi-google-login diff --git a/docs/client/flask.rst b/docs/oauth2/client/web/flask.rst similarity index 60% rename from docs/client/flask.rst rename to docs/oauth2/client/web/flask.rst index 2d44ed964..837de053e 100644 --- a/docs/client/flask.rst +++ b/docs/oauth2/client/web/flask.rst @@ -1,25 +1,23 @@ .. _flask_client: -Flask OAuth Client -================== +Flask Integration +================= .. meta:: - :description: The built-in Flask integrations for OAuth 1.0, OAuth 2.0 + :description: The built-in Flask integrations for OAuth 2.0 and OpenID Connect clients, powered by Authlib. .. module:: authlib.integrations.flask_client :noindex: -This documentation covers OAuth 1.0, OAuth 2.0 and OpenID Connect Client -support for Flask. Looking for OAuth providers? +This documentation covers OAuth 2.0 and OpenID Connect Client support for +Flask. Looking for OAuth 2.0 server? -- :ref:`flask_oauth1_server` - :ref:`flask_oauth2_server` -Flask OAuth client can handle OAuth 1 and OAuth 2 services. It shares a -similar API with Flask-OAuthlib, you can transfer your code from -Flask-OAuthlib to Authlib with ease. +Flask OAuth client shares a similar API with Flask-OAuthlib, you can transfer +your code from Flask-OAuthlib to Authlib with ease. Create a registry with :class:`OAuth` object:: @@ -35,24 +33,17 @@ You can also initialize it later with :meth:`~OAuth.init_app` method:: The common use case for OAuth is authentication, e.g. let your users log in with Twitter, GitHub, Google etc. -.. note:: +.. important:: Please read :ref:`frameworks_clients` at first. Authlib has a shared API design among framework integrations, learn them from :ref:`frameworks_clients`. -.. versionchanged:: v0.13 - - Authlib moved all integrations into ``authlib.integrations`` module since v0.13. - For earlier version, developers can import the Flask client with:: - - from authlib.flask.client import OAuth - Configuration ------------- Authlib Flask OAuth registry can load the configuration from Flask ``app.config`` -automatically. Every key value pair in ``.register`` can be omit. They can be -configured in your Flask App configuration. Config key is formatted with +automatically. Every key-value pair in ``.register`` can be omitted. They can be +configured in your Flask App configuration. Config keys are formatted as ``{name}_{key}`` in uppercase, e.g. ========================== ================================ @@ -62,7 +53,7 @@ TWITTER_REQUEST_TOKEN_URL URL to fetch OAuth request token ========================== ================================ If you register your remote app as ``oauth.register('example', ...)``, the -config key would look like: +config keys would look like: ========================== =============================== EXAMPLE_CLIENT_ID OAuth Consumer Key @@ -78,7 +69,7 @@ Here is a full list of the configuration keys: - ``{name}_REQUEST_TOKEN_PARAMS``: Extra parameters for Request Token endpoint - ``{name}_ACCESS_TOKEN_URL``: Access Token endpoint for OAuth 1 and OAuth 2 - ``{name}_ACCESS_TOKEN_PARAMS``: Extra parameters for Access Token endpoint -- ``{name}_AUTHORIZE_URL``: Endpoint for user authorization of OAuth 1 ro OAuth 2 +- ``{name}_AUTHORIZE_URL``: Endpoint for user authorization of OAuth 1 or OAuth 2 - ``{name}_AUTHORIZE_PARAMS``: Extra parameters for Authorization Endpoint. - ``{name}_API_BASE_URL``: A base URL endpoint to make requests simple - ``{name}_CLIENT_KWARGS``: Extra keyword arguments for OAuth1Session or OAuth2Session @@ -87,45 +78,24 @@ Here is a full list of the configuration keys: We suggest that you keep ONLY ``{name}_CLIENT_ID`` and ``{name}_CLIENT_SECRET`` in your Flask application configuration. -Using Cache for Temporary Credential ------------------------------------- - -By default, Flask OAuth registry will use Flask session to store OAuth 1.0 temporary -credential (request token). However in this way, there are chances your temporary -credential will be exposed. - -Our ``OAuth`` registry provides a simple way to store temporary credentials in a cache -system. When initializing ``OAuth``, you can pass an ``cache`` instance:: - - oauth = OAuth(app, cache=cache) - - # or initialize lazily - oauth = OAuth() - oauth.init_app(app, cache=cache) - -A ``cache`` instance MUST have methods: - -- ``.get(key)`` -- ``.set(key, value, expires=None)`` - - Routes for Authorization ------------------------ Unlike the examples in :ref:`frameworks_clients`, Flask does not pass a ``request`` into routes. In this case, the routes for authorization should look like:: - from flask import url_for, render_template + from flask import url_for, redirect @app.route('/login') def login(): redirect_uri = url_for('authorize', _external=True) - return oauth.twitter.authorize_redirect(redirect_uri) + return oauth.github.authorize_redirect(redirect_uri) @app.route('/authorize') def authorize(): - token = oauth.twitter.authorize_access_token() - resp = oauth.twitter.get('account/verify_credentials.json') + token = oauth.github.authorize_access_token() + resp = oauth.github.get('user') + resp.raise_for_status() profile = resp.json() # do something with the token and profile return redirect('/') @@ -134,7 +104,7 @@ Accessing OAuth Resources ------------------------- There is no ``request`` in accessing OAuth resources either. Just like above, -we don't need to pass ``request`` parameter, everything is handled by Authlib +we don't need to pass the ``request`` parameter, everything is handled by Authlib automatically:: from flask import render_template @@ -142,6 +112,7 @@ automatically:: @app.route('/github') def show_github_profile(): resp = oauth.github.get('user') + resp.raise_for_status() profile = resp.json() return render_template('github.html', profile=profile) @@ -150,18 +121,13 @@ In this case, our ``fetch_token`` could look like:: from your_project import current_user def fetch_token(name): - if name in OAUTH1_SERVICES: - model = OAuth1Token - else: - model = OAuth2Token - - token = model.find( + token = OAuth2Token.find( name=name, user=current_user, ) return token.to_token() - # initialize OAuth registry with this fetch_token function + # initialize the OAuth registry with this fetch_token function oauth = OAuth(fetch_token=fetch_token) You don't have to pass ``token``, you don't have to pass ``request``. That @@ -170,14 +136,11 @@ is the fantasy of Flask. Auto Update Token via Signal ---------------------------- -.. versionadded:: v0.13 - - The signal is added since v0.13 -Instead of define a ``update_token`` method and passing it into OAuth registry, -it is also possible to use signal to listen for token updating. +Instead of defining an ``update_token`` method and passing it into the OAuth registry, +it is also possible to use a signal to listen for token updating. -Before using signal, make sure you have installed **blinker** library:: +Before using the signal, make sure you have installed the **blinker** library:: $ pip install blinker @@ -205,7 +168,7 @@ Flask OpenID Connect Client --------------------------- An OpenID Connect client is no different than a normal OAuth 2.0 client. When -register with ``openid`` scope, the built-in Flask OAuth client will handle everything +registered with ``openid`` scope, the built-in Flask OAuth client will handle everything automatically:: oauth.register( @@ -219,14 +182,53 @@ When we get the returned token:: token = oauth.google.authorize_access_token() -We can get the user information from the ``id_token`` in the returned token:: +There should be a ``id_token`` in the response. Authlib has called `.parse_id_token` +automatically, we can get ``userinfo`` in the ``token``:: + + userinfo = token['userinfo'] + +RP-Initiated Logout +------------------- + +To implement `OpenID Connect RP-Initiated Logout`_, use the ``logout_redirect`` method +to redirect users to the provider's end session endpoint:: + + @app.route('/logout') + def logout(): + # Retrieve the ID token you stored during login + id_token = session.pop('id_token', None) + return oauth.google.logout_redirect( + post_logout_redirect_uri=url_for('logged_out', _external=True), + id_token_hint=id_token, + ) + + @app.route('/logged-out') + def logged_out(): + state_data = oauth.google.validate_logout_response() + return 'You have been logged out.' + +.. _OpenID Connect RP-Initiated Logout: https://openid.net/specs/openid-connect-rpinitiated-1_0.html - userinfo = oauth.google.parse_id_token(token) +The ``logout_redirect`` method accepts: + +- ``post_logout_redirect_uri``: Where to redirect after logout (must be registered with the provider) +- ``id_token_hint``: The ID token previously issued (recommended) +- ``state``: Opaque value for CSRF protection (auto-generated if not provided) +- ``client_id``: OAuth 2.0 Client Identifier (optional) +- ``logout_hint``: Hint about the user logging out (optional) +- ``ui_locales``: Preferred languages for the logout UI (optional) + +.. note:: + + You must store the ``id_token`` during login to use it later for logout. + The ``id_token`` is available in ``token['id_token']`` after calling + ``authorize_access_token()``. Examples --------- -Here are some example code for you learn Flask OAuth client integrations: +Here are some example projects for you to learn Flask OAuth 2.0 client integrations: + +1. `Flask Google Login`_. -1. OAuth 1.0: `Flask Twitter login `_ -2. OAuth 2.0 & OpenID Connect: `Flask Google login `_ +.. _`Flask Google Login`: https://github.com/authlib/demo-oauth-client/tree/master/flask-google-login diff --git a/docs/client/frameworks.rst b/docs/oauth2/client/web/index.rst similarity index 64% rename from docs/client/frameworks.rst rename to docs/oauth2/client/web/index.rst index 9cb803df4..5f7b9978f 100644 --- a/docs/client/frameworks.rst +++ b/docs/oauth2/client/web/index.rst @@ -1,13 +1,12 @@ .. _frameworks_clients: -Web OAuth Clients -================= +Web Clients +=========== .. module:: authlib.integrations :noindex: -This documentation covers OAuth 1.0 and OAuth 2.0 integrations for -Python Web Frameworks like: +This documentation covers OAuth 2.0 integrations for Python Web Frameworks like: * Django: The web framework for perfectionists with deadlines * Flask: The Python micro framework for building web applications @@ -39,99 +38,7 @@ documentation later: 3. :class:`starlette_client.OAuth` for :ref:`starlette_client` The common use case for OAuth is authentication, e.g. let your users log in -with Twitter, GitHub, Google etc. - -Log In with OAuth 1.0 ---------------------- - -For instance, Twitter is an OAuth 1.0 service, you want your users to log in -your website with Twitter. - -The first step is register a remote application on the ``OAuth`` registry via -``oauth.register`` method:: - - oauth.register( - name='twitter', - client_id='{{ your-twitter-consumer-key }}', - client_secret='{{ your-twitter-consumer-secret }}', - request_token_url='https://api.twitter.com/oauth/request_token', - request_token_params=None, - access_token_url='https://api.twitter.com/oauth/access_token', - access_token_params=None, - authorize_url='https://api.twitter.com/oauth/authenticate', - authorize_params=None, - api_base_url='https://api.twitter.com/1.1/', - client_kwargs=None, - ) - -The first parameter in ``register`` method is the **name** of the remote -application. You can access the remote application with:: - - twitter = oauth.create_client('twitter') - # or simply with - twitter = oauth.twitter - -The configuration of those parameters can be loaded from the framework -configuration. Each framework has its own config system, read the framework -specified documentation later. - -For instance, if ``client_id`` and ``client_secret`` can be loaded via -configuration, we can simply register the remote app with:: - - oauth.register( - name='twitter', - request_token_url='https://api.twitter.com/oauth/request_token', - access_token_url='https://api.twitter.com/oauth/access_token', - authorize_url='https://api.twitter.com/oauth/authenticate', - api_base_url='https://api.twitter.com/1.1/', - ) - -The ``client_kwargs`` is a dict configuration to pass extra parameters to -:ref:`oauth_1_session`. If you are using ``RSA-SHA1`` signature method:: - - client_kwargs = { - 'signature_method': 'RSA-SHA1', - 'signature_type': 'HEADER', - 'rsa_key': 'Your-RSA-Key' - } - - -Saving Temporary Credential -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Usually, the framework integration has already implemented this part through -the framework session system. All you need to do is enable session for the -chosen framework. - -Routes for Authorization -~~~~~~~~~~~~~~~~~~~~~~~~ - -After configuring the ``OAuth`` registry and the remote application, the -rest steps are much simpler. The only required parts are routes: - -1. redirect to 3rd party provider (Twitter) for authentication -2. redirect back to your website to fetch access token and profile - -Here is the example for Twitter login:: - - def login(request): - twitter = oauth.create_client('twitter') - redirect_uri = 'https://example.com/authorize' - return twitter.authorize_redirect(request, redirect_uri) - - def authorize(request): - twitter = oauth.create_client('twitter') - token = twitter.authorize_access_token(request) - resp = twitter.get('account/verify_credentials.json') - profile = resp.json() - # do something with the token and profile - return '...' - -After user confirmed on Twitter authorization page, it will redirect -back to your website ``authorize`` page. In this route, you can get your -user's twitter profile information, you can store the user information -in your database, mark your user as logged in and etc. - +with GitHub, Google etc. Using OAuth 2.0 to Log In ------------------------- @@ -166,7 +73,7 @@ configuration. Each framework has its own config system, read the framework specified documentation later. The ``client_kwargs`` is a dict configuration to pass extra parameters to -:ref:`oauth_2_session`, you can pass extra parameters like:: +:ref:`OAuth 2 Session `, you can pass extra parameters like:: client_kwargs = { 'scope': 'profile', @@ -177,12 +84,6 @@ The ``client_kwargs`` is a dict configuration to pass extra parameters to There are several ``token_endpoint_auth_method``, get a deep inside the :ref:`client_auth_methods`. -.. note:: - - Authlib is using ``request_token_url`` to detect if the client is an - OAuth 1.0 or OAuth 2.0 client. In OAuth 2.0, there is no ``request_token_url``. - - Routes for Authorization ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -202,6 +103,7 @@ Here is the example for GitHub login:: def authorize(request): token = oauth.github.authorize_access_token(request) resp = oauth.github.get('user', token=token) + resp.raise_for_status() profile = resp.json() # do something with the token and profile return '...' @@ -211,17 +113,6 @@ back to your website ``authorize``. In this route, you can get your user's GitHub profile information, you can store the user information in your database, mark your user as logged in and etc. -.. note:: - - You may find that our documentation for OAuth 1.0 and OAuth 2.0 are - the same. They are designed to share the same API, so that you use - the same code for both OAuth 1.0 and OAuth 2.0. - - The ONLY difference is the configuration. OAuth 1.0 contains - ``request_token_url`` and ``request_token_params`` while OAuth 2.0 - not. Also, the ``client_kwargs`` are different. - - Client Authentication Methods ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -262,22 +153,13 @@ Accessing OAuth Resources .. note:: If your application ONLY needs login via 3rd party services like - Twitter, Google, Facebook and GitHub to login, you DON'T need to + Google, Facebook and GitHub to login, you DON'T need to create the token database. There are also chances that you need to access your user's 3rd party OAuth provider resources. For instance, you want to display the logged -in user's twitter time line and GitHub repositories. You will use -**access token** to fetch the resources:: - - def get_twitter_tweets(request): - token = OAuth1Token.find( - name='twitter', - user=request.user - ) - # API URL: https://api.twitter.com/1.1/statuses/user_timeline.json - resp = oauth.twitter.get('statuses/user_timeline.json', token=token.to_token()) - return resp.json() +in user's GitHub repositories. You will use **access token** to fetch +the resources:: def get_github_repositories(request): token = OAuth2Token.find( @@ -286,6 +168,7 @@ in user's twitter time line and GitHub repositories. You will use ) # API URL: https://api.github.com/user/repos resp = oauth.github.get('user/repos', token=token.to_token()) + resp.raise_for_status() return resp.json() In this case, we need a place to store the access token in order to use @@ -297,24 +180,7 @@ database. Design Database ~~~~~~~~~~~~~~~ -It is possible to share one database table for both OAuth 1.0 token and -OAuth 2.0 token. It is also good to use different database tables for -OAuth 1.0 and OAuth 2.0. - -In the above example, we are using two tables. Here are some hints on -how to design the database:: - - class OAuth1Token(Model): - name = String(length=40) - oauth_token = String(length=200) - oauth_token_secret = String(length=200) - user = ForeignKey(User) - - def to_token(self): - return dict( - oauth_token=self.access_token, - oauth_token_secret=self.alt_token, - ) +Here are some hints on how to design the OAuth 2.0 token database:: class OAuth2Token(Model): name = String(length=40) @@ -343,12 +209,6 @@ Fetch User OAuth Token You can always pass a ``token`` parameter to the remote application request methods, like:: - token = OAuth1Token.find(name='twitter', user=request.user) - oauth.twitter.get(url, token=token) - oauth.twitter.post(url, token=token) - oauth.twitter.put(url, token=token) - oauth.twitter.delete(url, token=token) - token = OAuth2Token.find(name='github', user=request.user) oauth.github.get(url, token=token) oauth.github.post(url, token=token) @@ -359,13 +219,6 @@ However, it is not a good practice to query the token database in every request function. Authlib provides a way to fetch current user's token automatically for you, just ``register`` with ``fetch_token`` function:: - def fetch_twitter_token(request): - token = OAuth1Token.find( - name='twitter', - user=request.user - ) - return token.to_token() - def fetch_github_token(request): token = OAuth2Token.find( name='github', @@ -374,27 +227,16 @@ you, just ``register`` with ``fetch_token`` function:: return token.to_token() # we can registry this ``fetch_token`` with oauth.register - oauth.register( - 'twitter', - # ... - fetch_token=fetch_twitter_token, - ) oauth.register( 'github', # ... fetch_token=fetch_github_token, ) -Not good enough. In this way, you have to write ``fetch_token`` for every -remote application. There is also a shared way to fetch token:: +There is also a shared way to fetch token:: def fetch_token(name, request): - if name in OAUTH1_SERVICES: - model = OAuth1Token - else: - model = OAuth2Token - - token = model.find( + token = OAuth2Token.find( name=name, user=request.user ) @@ -406,8 +248,9 @@ remote application. There is also a shared way to fetch token:: Now, developers don't have to pass a ``token`` in the HTTP requests, instead, they can pass the ``request``:: - def get_twitter_tweets(request): - resp = oauth.twitter.get('statuses/user_timeline.json', request=request) + def get_github_repos(request): + resp = oauth.github.get('user/repos', request=request) + resp.raise_for_status() return resp.json() @@ -479,6 +322,7 @@ a method to fix the requests session:: def slack_compliance_fix(session): def _fix(resp): + resp.raise_for_status() token = resp.json() # slack returns no token_type token['token_type'] = 'Bearer' @@ -503,7 +347,7 @@ Find all the available compliance hooks at :ref:`compliance_fix_oauth2`. OpenID Connect & UserInfo ------------------------- -When log in with OAuth 1.0 and OAuth 2.0, "access_token" is not what developers +When logging in with OpenID Connect, "access_token" is not what developers want. Instead, what developers want is **user info**, Authlib wrap it with :class:`~authlib.oidc.core.UserInfo`. @@ -524,30 +368,13 @@ Passing a ``userinfo_endpoint`` when ``.register`` remote client:: userinfo_endpoint='https://openidconnect.googleapis.com/v1/userinfo', ) -And later, when the client has obtained access token, we can call:: +And later, when the client has obtained the access token, we can call:: def authorize(request): token = oauth.google.authorize_access_token(request) - user = oauth.google.userinfo(request) + user = oauth.google.userinfo(token=token) return '...' -If the ``userinfo_endpoint`` is not compatible with -:class:`~authlib.oidc.core.UserInfo`, we can use a ``userinfo_compliance_fix``:: - - - def compliance_fix(client, user_data): - return { - 'sub': user_data['id'], - 'name': user_data['name'] - } - - oauth.register( - 'example', - client_id='...', - client_secret='...', - userinfo_endpoint='https://example.com/userinfo', - userinfo_compliance_fix=compliance_fix, - ) Parsing ``id_token`` ~~~~~~~~~~~~~~~~~~~~ @@ -570,7 +397,7 @@ A simple solution is to provide the OpenID Connect Discovery Endpoint:: client_kwargs={'scope': 'openid email profile'}, ) -The discovery endpoint provides all the information we need so that you don't +The discovery endpoint provides all the information we need so that we don't have to add ``authorize_url`` and ``access_token_url``. Check out our client example: https://github.com/authlib/demo-oauth-client @@ -593,3 +420,58 @@ provide the value of ``jwks`` instead of ``jwks_uri``:: authorize_url='https://example.com/oauth/authorize', jwks={"keys": [...]} ) + + +RP-Initiated Logout +------------------- + +`OpenID Connect RP-Initiated Logout`_ allows users to log out from the +OpenID Provider when they log out from your application. This is useful +to ensure that the user's session at the provider is also terminated. + +.. _OpenID Connect RP-Initiated Logout: https://openid.net/specs/openid-connect-rpinitiated-1_0.html + +To use RP-Initiated Logout, the provider must support the ``end_session_endpoint`` +in its OpenID Connect discovery document. Authlib provides a ``logout_redirect`` +method to redirect users to this endpoint:: + + def logout(request): + client = oauth.create_client('google') + # Retrieve the ID token you stored during login + id_token = get_stored_id_token(request.user) + return client.logout_redirect( + request, + post_logout_redirect_uri='https://example.com/logged-out', + id_token_hint=id_token, + ) + +The ``logout_redirect`` method accepts: + +- ``post_logout_redirect_uri``: Where to redirect after logout (must be registered with the provider) +- ``id_token_hint``: The ID token previously issued (recommended) +- ``state``: Opaque value for CSRF protection (auto-generated if not provided) +- ``client_id``: OAuth 2.0 Client Identifier (optional) +- ``logout_hint``: Hint about the user logging out (optional) +- ``ui_locales``: Preferred languages for the logout UI (optional) + +.. important:: + + You must store the ``id_token`` during login to use it later for logout. + The ``id_token`` is available in ``token['id_token']`` after calling + ``authorize_access_token()``. + +Each framework has slightly different syntax. See the framework-specific documentation +for detailed examples: + +- :ref:`flask_client` +- :ref:`django_client` +- :ref:`starlette_client` + +.. toctree:: + :maxdepth: 1 + + flask + django + starlette + fastapi + api diff --git a/docs/oauth2/client/web/starlette.rst b/docs/oauth2/client/web/starlette.rst new file mode 100644 index 000000000..39c98c0ee --- /dev/null +++ b/docs/oauth2/client/web/starlette.rst @@ -0,0 +1,139 @@ +.. _starlette_client: + +Starlette Integration +===================== + +.. meta:: + :description: The built-in Starlette integrations for OAuth 2.0 + and OpenID Connect clients, powered by Authlib. + +.. module:: authlib.integrations.starlette_client + :noindex: + +Starlette_ is a lightweight ASGI framework/toolkit, which is ideal for +building high performance asyncio services. + +.. _Starlette: https://www.starlette.io/ + +This documentation covers OAuth 2.0 and OpenID Connect Client support for +Starlette. Because all the frameworks integrations share the same API, it is +best to: + +Read :ref:`frameworks_clients` at first. + +The difference between Starlette and Flask/Django integrations is Starlette +is **async**. We will use ``await`` for the functions we need to call. But +first, let's create an :class:`OAuth` instance:: + + from authlib.integrations.starlette_client import OAuth + + oauth = OAuth() + +The common use case for OAuth is authentication, e.g. let your users log in +with Twitter, GitHub, Google etc. + +Register Remote Apps +-------------------- + +``oauth.register`` is the same as :ref:`frameworks_clients`:: + + oauth.register( + 'google', + client_id='...', + client_secret='...', + ... + ) + +However, unlike Flask/Django, Starlette OAuth registry uses HTTPX +:class:`~authlib.integrations.httpx_client.AsyncOAuth2Client` as the OAuth 2.0 +backend. While Flask and Django are using the Requests version of +:class:`~authlib.integrations.requests_client.OAuth2Session`. + +Routes for Authorization +------------------------ + +Just like the examples in :ref:`frameworks_clients`, but Starlette is **async**, +the routes for authorization should look like:: + + @app.route('/login/google') + async def login_via_google(request): + google = oauth.create_client('google') + redirect_uri = request.url_for('authorize_google') + return await google.authorize_redirect(request, redirect_uri) + + @app.route('/auth/google') + async def authorize_google(request): + google = oauth.create_client('google') + token = await google.authorize_access_token(request) + # do something with the token and userinfo + return '...' + +Starlette OpenID Connect +------------------------ + +An OpenID Connect client is no different than a normal OAuth 2.0 client, just add +``openid`` scope when ``.register``. The built-in Starlette OAuth client will handle +everything automatically:: + + oauth.register( + 'google', + ... + server_metadata_url='https://accounts.google.com/.well-known/openid-configuration', + client_kwargs={'scope': 'openid profile email'} + ) + +When we get the returned token:: + + token = await oauth.google.authorize_access_token() + +There should be a ``id_token`` in the response. Authlib has called `.parse_id_token` +automatically, we can get ``userinfo`` in the ``token``:: + + userinfo = token['userinfo'] + +RP-Initiated Logout +------------------- + +To implement `OpenID Connect RP-Initiated Logout`_, use the ``logout_redirect`` method +to redirect users to the provider's end session endpoint:: + + @app.route('/logout') + async def logout(request): + # Retrieve the ID token you stored during login + id_token = request.session.pop('id_token', None) + redirect_uri = request.url_for('logged_out') + return await oauth.google.logout_redirect( + request, + post_logout_redirect_uri=str(redirect_uri), + id_token_hint=id_token, + ) + + @app.route('/logged-out') + async def logged_out(request): + state_data = await oauth.google.validate_logout_response(request) + return PlainTextResponse('You have been logged out.') + +.. _OpenID Connect RP-Initiated Logout: https://openid.net/specs/openid-connect-rpinitiated-1_0.html + +The ``logout_redirect`` method accepts: + +- ``request``: The Starlette request object (required) +- ``post_logout_redirect_uri``: Where to redirect after logout (must be registered with the provider) +- ``id_token_hint``: The ID token previously issued (recommended) +- ``state``: Opaque value for CSRF protection (auto-generated if not provided) +- ``client_id``: OAuth 2.0 Client Identifier (optional) +- ``logout_hint``: Hint about the user logging out (optional) +- ``ui_locales``: Preferred languages for the logout UI (optional) + +.. note:: + + You must store the ``id_token`` during login to use it later for logout. + The ``id_token`` is available in ``token['id_token']`` after calling + ``authorize_access_token()``. + +Examples +-------- + +We have Starlette demos at https://github.com/authlib/demo-oauth-client + +1. `Starlette Google login `_ diff --git a/docs/oauth/2/intro.rst b/docs/oauth2/concepts.rst similarity index 98% rename from docs/oauth/2/intro.rst rename to docs/oauth2/concepts.rst index 9e4f2039a..861622251 100644 --- a/docs/oauth/2/intro.rst +++ b/docs/oauth2/concepts.rst @@ -5,8 +5,8 @@ .. _intro_oauth2: -Introduce OAuth 2.0 -=================== +Concepts +======== The OAuth 2.0 authorization framework enables a third-party application to obtain limited access to an HTTP service, either on behalf of a resource owner @@ -151,7 +151,7 @@ Scope is a very important concept in OAuth 2.0. An access token is usually issue with limited scopes. For instance, your "source code analyzer" application MAY only have access to the -public repositories of a GiHub user. +public repositories of a GitHub user. Endpoints --------- diff --git a/docs/oauth2/index.rst b/docs/oauth2/index.rst new file mode 100644 index 000000000..43c94d0b6 --- /dev/null +++ b/docs/oauth2/index.rst @@ -0,0 +1,11 @@ +OAuth 2.0 & OIDC +================ + +.. toctree:: + :maxdepth: 2 + + concepts + client/index + authorization-server/index + resource-server/index + specs/index diff --git a/docs/django/2/resource-server.rst b/docs/oauth2/resource-server/django.rst similarity index 72% rename from docs/django/2/resource-server.rst rename to docs/oauth2/resource-server/django.rst index a1e32815a..e312935f9 100644 --- a/docs/django/2/resource-server.rst +++ b/docs/oauth2/resource-server/django.rst @@ -1,5 +1,7 @@ -Resource Server -=============== +.. _django_resource_server: + +Django Integration +================== Protect users resources, so that only the authorized clients with the authorized access token can access the given scope resources. @@ -18,16 +20,9 @@ server. Here is the way to protect your users' resources in Django:: user = request.oauth_token.user return JsonResponse(dict(sub=user.pk, username=user.username)) -If the resource is not protected by a scope, use ``None``:: - - @require_oauth() - def user_profile(request): - user = request.oauth_token.user - return JsonResponse(dict(sub=user.pk, username=user.username)) - - # or with None +If the resource is not protected by a scope, omit the argument:: - @require_oauth(None) + @require_oauth def user_profile(request): user = request.oauth_token.user return JsonResponse(dict(sub=user.pk, username=user.username)) @@ -38,12 +33,14 @@ which is the instance of current in-use Token. Multiple Scopes --------------- -You can apply multiple scopes to one endpoint in **AND** and **OR** modes. -The default is **AND** mode. +.. versionchanged:: v1.0 + +You can apply multiple scopes to one endpoint in **AND**, **OR** and mix modes. +Here are some examples: .. code-block:: python - @require_oauth('profile email', 'AND') + @require_oauth(['profile email']) def user_profile(request): user = request.oauth_token.user return JsonResponse(dict(sub=user.pk, username=user.username)) @@ -52,24 +49,26 @@ It requires the token containing both ``profile`` and ``email`` scope. .. code-block:: python - @require_oauth('profile email', 'OR') + @require_oauth(['profile', 'email']) def user_profile(request): user = request.oauth_token.user return JsonResponse(dict(sub=user.pk, username=user.username)) It requires the token containing either ``profile`` or ``email`` scope. -It is also possible to pass a function as the scope operator. e.g.:: - def scope_operator(token_scopes, resource_scopes): - # this equals "AND" - return token_scopes.issuperset(resource_scopes) +It is also possible to mix **AND** and **OR** logic. e.g.:: - @require_oauth('profile email', scope_operator) + @app.route('/profile') + @require_oauth(['profile email', 'user']) def user_profile(request): user = request.oauth_token.user return JsonResponse(dict(sub=user.pk, username=user.username)) +This means if the token will be valid if: + +1. token contains both ``profile`` and ``email`` scope +2. or token contains ``user`` scope Optional ``require_oauth`` -------------------------- diff --git a/docs/flask/2/resource-server.rst b/docs/oauth2/resource-server/flask.rst similarity index 66% rename from docs/flask/2/resource-server.rst rename to docs/oauth2/resource-server/flask.rst index 2bbbef7bf..b6cab8a86 100644 --- a/docs/flask/2/resource-server.rst +++ b/docs/oauth2/resource-server/flask.rst @@ -1,13 +1,13 @@ .. _flask_oauth2_resource_protector: -Resource Server -=============== +Flask Integration +================= -Protect users resources, so that only the authorized clients with the +Protects users resources, so that only the authorized clients with the authorized access token can access the given scope resources. A resource server can be a different server other than the authorization -server. Here is the way to protect your users' resources:: +server. Authlib offers a **decorator** to protect your API endpoints:: from flask import jsonify from authlib.integrations.flask_oauth2 import ResourceProtector, current_token @@ -17,41 +17,28 @@ server. Here is the way to protect your users' resources:: def authenticate_token(self, token_string): return Token.query.filter_by(access_token=token_string).first() - def request_invalid(self, request): - return False - - def token_revoked(self, token): - return token.revoked - require_oauth = ResourceProtector() # only bearer token is supported currently require_oauth.register_token_validator(MyBearerTokenValidator()) - # you can also create BearerTokenValidator with shortcut - from authlib.integrations.sqla_oauth2 import create_bearer_token_validator +When the resource server has no access to the ``Token`` model (database), and +there is an introspection token endpoint in authorization server, you can +:ref:`require_oauth_introspection`. - BearerTokenValidator = create_bearer_token_validator(db.session, Token) - require_oauth.register_token_validator(BearerTokenValidator()) +Here is the way to protect your users' resources:: @app.route('/user') @require_oauth('profile') def user_profile(): + # if Token model has `.user` foreign key user = current_token.user return jsonify(user) -If the resource is not protected by a scope, use ``None``:: - - @app.route('/user') - @require_oauth() - def user_profile(): - user = current_token.user - return jsonify(user) - - # or with None +If the resource is not protected by a scope, omit the argument:: @app.route('/user') - @require_oauth(None) + @require_oauth def user_profile(): user = current_token.user return jsonify(user) @@ -60,7 +47,7 @@ The ``current_token`` is a proxy to the Token model you have defined above. Since there is a ``user`` relationship on the Token model, we can access this ``user`` with ``current_token.user``. -If decorator is not your favorite, there is a ``with`` statement for you:: +If the decorator is not your favorite, there is a ``with`` statement for you:: @app.route('/user') def user_profile(): @@ -73,13 +60,15 @@ If decorator is not your favorite, there is a ``with`` statement for you:: Multiple Scopes --------------- -You can apply multiple scopes to one endpoint in **AND** and **OR** modes. -The default is **AND** mode. +.. versionchanged:: v1.0 + +You can apply multiple scopes to one endpoint in **AND**, **OR** and mix modes. +Here are some examples: .. code-block:: python @app.route('/profile') - @require_oauth('profile email', 'AND') + @require_oauth(['profile email']) def user_profile(): user = current_token.user return jsonify(user) @@ -89,25 +78,25 @@ It requires the token containing both ``profile`` and ``email`` scope. .. code-block:: python @app.route('/profile') - @require_oauth('profile email', 'OR') + @require_oauth(['profile', 'email']') def user_profile(): user = current_token.user return jsonify(user) It requires the token containing either ``profile`` or ``email`` scope. -It is also possible to pass a function as the scope operator. e.g.:: - - def scope_operator(token_scopes, resource_scopes): - # this equals "AND" - return token_scopes.issuperset(resource_scopes) +It is also possible to mix **AND** and **OR** logic. e.g.:: @app.route('/profile') - @require_oauth('profile email', scope_operator) + @require_oauth(['profile email', 'user']) def user_profile(): user = current_token.user return jsonify(user) +This means if the token will be valid if: + +1. token contains both ``profile`` and ``email`` scope +2. or token contains ``user`` scope Optional ``require_oauth`` -------------------------- @@ -138,4 +127,3 @@ and ``flask_restful.Resource``:: class UserAPI(Resource): method_decorators = [require_oauth('profile')] - diff --git a/docs/oauth2/resource-server/index.rst b/docs/oauth2/resource-server/index.rst new file mode 100644 index 000000000..b6ad5831d --- /dev/null +++ b/docs/oauth2/resource-server/index.rst @@ -0,0 +1,38 @@ +.. _resource_server: + +Resource Server +=============== + +A Resource Server is an API that accepts and validates access tokens to protect +resources. Build this when you want to secure your API endpoints so that only +authorized clients can access them. + +A resource server can be separate from the authorization server — it only needs +to validate the tokens that the authorization server issued. + +Not sure this is the right role? See :ref:`intro_oauth2` for an overview of +all OAuth 2.0 roles. + +Looking for the :ref:`authorization_server` (issuing tokens)? +Or the :ref:`oauth_client` (consuming an OAuth provider)? + +Understand +---------- + +* :ref:`intro_oauth2` — OAuth 2.0 roles and token validation +* :doc:`../specs/rfc6750` — Bearer Token Usage + +How-to +------ + +.. toctree:: + :maxdepth: 2 + + flask + django + +Reference +--------- + +* :doc:`../specs/rfc6750` — Bearer Token Usage +* :doc:`../specs/rfc7662` — Token Introspection diff --git a/docs/oauth2/specs/index.rst b/docs/oauth2/specs/index.rst new file mode 100644 index 000000000..2bdef3fd5 --- /dev/null +++ b/docs/oauth2/specs/index.rst @@ -0,0 +1,21 @@ +Specifications +============== + +.. toctree:: + :maxdepth: 1 + + rfc6749 + rfc6750 + rfc7009 + rfc7523 + rfc7591 + rfc7592 + rfc7636 + rfc7662 + rfc8414 + rfc8628 + rfc9068 + rfc9101 + rfc9207 + oidc + rpinitiated diff --git a/docs/oauth2/specs/oidc.rst b/docs/oauth2/specs/oidc.rst new file mode 100644 index 000000000..dcaf9fa61 --- /dev/null +++ b/docs/oauth2/specs/oidc.rst @@ -0,0 +1,96 @@ +.. _specs/oidc: + +OpenID Connect 1.0 +================== + +.. meta:: + :description: General implementation of OpenID Connect 1.0 in Python. + Learn how to create a OpenID Connect provider in Python. + +This part of the documentation covers the specification of OpenID Connect. Learn +how to use it in :ref:`flask_oidc_server` and :ref:`django_oidc_server`. + +OpenID Grants +------------- + +.. module:: authlib.oidc.core.grants + +.. autoclass:: OpenIDToken + :show-inheritance: + :members: + +.. autoclass:: OpenIDCode + :show-inheritance: + :members: + +.. autoclass:: OpenIDImplicitGrant + :show-inheritance: + :members: + +.. autoclass:: OpenIDHybridGrant + :show-inheritance: + :members: + +OpenID Endpoints +---------------- + +.. module:: authlib.oidc.core + +.. autoclass:: UserInfoEndpoint + :show-inheritance: + :members: + +OpenID Claims +------------- + +.. module:: authlib.oidc.core.claims + +.. autoclass:: IDToken + :show-inheritance: + :members: + + +.. autoclass:: CodeIDToken + :show-inheritance: + :members: + + +.. autoclass:: ImplicitIDToken + :show-inheritance: + :members: + + +.. autoclass:: HybridIDToken + :show-inheritance: + :members: + +.. autoclass:: UserInfo + :members: + +Dynamic client registration +--------------------------- + +The `OpenID Connect Dynamic Client Registration `__ implementation is based on :ref:`RFC7591: OAuth 2.0 Dynamic Client Registration Protocol `. To handle OIDC client registration, you can extend your RFC7591 registration endpoint with OIDC claims:: + + from authlib.oauth2.rfc7591 import ClientMetadataClaims as OAuth2ClientMetadataClaims + from authlib.oauth2.rfc7591 import ClientRegistrationEndpoint + from authlib.oidc.registration import ClientMetadataClaims as OIDCClientMetadataClaims + + class MyClientRegistrationEndpoint(ClientRegistrationEndpoint): + ... + + def get_server_metadata(self): + ... + + authorization_server.register_endpoint( + MyClientRegistrationEndpoint( + claims_classes=[OAuth2ClientMetadataClaims, OIDCClientMetadataClaims] + ) + ) + + + +.. automodule:: authlib.oidc.registration + :show-inheritance: + :members: + diff --git a/docs/specs/rfc6749.rst b/docs/oauth2/specs/rfc6749.rst similarity index 96% rename from docs/specs/rfc6749.rst rename to docs/oauth2/specs/rfc6749.rst index 8174b0311..f0273d563 100644 --- a/docs/specs/rfc6749.rst +++ b/docs/oauth2/specs/rfc6749.rst @@ -11,7 +11,7 @@ This section contains the generic implementation of RFC6749_. You should read :ref:`intro_oauth2` at first. Here are some tips: 1. Have a better understanding of :ref:`OAuth 2.0 ` -2. How to use :ref:`oauth_2_session` for Requests +2. How to use :ref:`OAuth 2 Session ` for Requests 3. How to implement :ref:`flask_client` 4. How to implement :ref:`flask_oauth2_server` 5. How to implement :ref:`django_client` diff --git a/docs/specs/rfc6750.rst b/docs/oauth2/specs/rfc6750.rst similarity index 100% rename from docs/specs/rfc6750.rst rename to docs/oauth2/specs/rfc6750.rst diff --git a/docs/specs/rfc7009.rst b/docs/oauth2/specs/rfc7009.rst similarity index 100% rename from docs/specs/rfc7009.rst rename to docs/oauth2/specs/rfc7009.rst diff --git a/docs/specs/rfc7523.rst b/docs/oauth2/specs/rfc7523.rst similarity index 76% rename from docs/specs/rfc7523.rst rename to docs/oauth2/specs/rfc7523.rst index d38e72cb4..afeb9422a 100644 --- a/docs/specs/rfc7523.rst +++ b/docs/oauth2/specs/rfc7523.rst @@ -21,6 +21,9 @@ This section contains the generic Python implementation of RFC7523_. Using JWTs as Authorization Grants ---------------------------------- +.. versionchanged:: v1.0.0 + Please note that all not-implemented methods are changed. + JWT Profile for OAuth 2.0 Authorization Grants works in the same way with :ref:`RFC6749 ` built-in grants. Which means it can be registered with :meth:`~authlib.oauth2.rfc6749.AuthorizationServer.register_grant`. @@ -28,24 +31,36 @@ registered with :meth:`~authlib.oauth2.rfc6749.AuthorizationServer.register_gran The base class is :class:`JWTBearerGrant`, you need to implement the missing methods in order to use it. Here is an example:: - from authlib.jose import jwk + from joserfc.jwk import KeySet from authlib.oauth2.rfc7523 import JWTBearerGrant as _JWTBearerGrant class JWTBearerGrant(_JWTBearerGrant): - def authenticate_user(self, client, claims): - # get user from claims info, usually it is claims['sub'] - # for anonymous user, return None - return None - - def authenticate_client(self, claims): - # get client from claims, usually it is claims['iss'] - # since the assertion JWT is generated by this client - return get_client_by_iss(claims['iss']) - - def resolve_public_key(self, headers, payload): - # get public key to decode the assertion JWT - jwk_set = get_client_public_keys(claims['iss']) - return jwk.loads(jwk_set, header.get('kid')) + def get_audiences(self): + # Per RFC 7523 Section 3, both the token endpoint URL and the + # authorization server's issuer identifier are valid audience values. + return ['https://example.com/oauth/token', 'https://example.com'] + + def resolve_issuer_client(self, issuer): + # if using client_id as issuer + return Client.objects.get(client_id=issuer) + + def resolve_client_key(self, client, headers, payload): + # if client has `jwks` column + key_set = KeySet.import_key_set(client.jwks) + return key_set + + def authenticate_user(self, subject): + # when assertion contains `sub` value, if this `sub` is email + return User.objects.get(email=subject) + + def has_granted_permission(self, client, user): + # check if the client has access to user's resource. + # for instance, we have a table `UserGrant`, which user can add client + # to this table to record that client has granted permission + grant = UserGrant.objects.get(client_id=client.client_id, user_id=user.id) + if grant: + return grant.enabled + return False # register grant to authorization server authorization_server.register_grant(JWTBearerGrant) @@ -80,6 +95,11 @@ In Authlib, ``client_secret_jwt`` and ``private_key_jwt`` share the same API, using :class:`JWTBearerClientAssertion` to create a new client authentication:: class JWTClientAuth(JWTBearerClientAssertion): + def get_audiences(self): + # Per RFC 7523 Section 3, both the token endpoint URL and the + # authorization server's issuer identifier are valid audience values. + return ['https://example.com/oauth/token', 'https://example.com'] + def validate_jti(self, claims, jti): # validate_jti is required by OpenID Connect # but it is optional by RFC7523 @@ -99,11 +119,13 @@ using :class:`JWTBearerClientAssertion` to create a new client authentication:: authorization_server.register_client_auth_method( JWTClientAuth.CLIENT_AUTH_METHOD, - JWTClientAuth('https://example.com/oauth/token') + JWTClientAuth() ) -The value ``https://example.com/oauth/token`` is your authorization servers's -token endpoint, which is used as ``aud`` value in JWT. +The ``get_audiences`` method returns the list of valid ``aud`` values +accepted in client assertion JWTs. Per RFC 7523 Section 3, +both the token endpoint URL and the authorization server's issuer +identifier are valid audience values. Now we have added this client auth method to authorization server, but no grant types support this authentication method, you need to add it to the @@ -141,7 +163,7 @@ alg, in this case, you can also alter ``CLIENT_AUTH_METHOD = 'client_secret_jwt' Using JWTs Client Assertion in OAuth2Session -------------------------------------------- -Authlib RFC7523 provides two more client authentication methods for :ref:`oauth_2_session`: +Authlib RFC7523 provides two more client authentication methods for :ref:`OAuth 2 Session `: 1. ``client_secret_jwt`` 2. ``private_key_jwt`` diff --git a/docs/specs/rfc7591.rst b/docs/oauth2/specs/rfc7591.rst similarity index 91% rename from docs/specs/rfc7591.rst rename to docs/oauth2/specs/rfc7591.rst index 60c34b912..82e12a8ec 100644 --- a/docs/specs/rfc7591.rst +++ b/docs/oauth2/specs/rfc7591.rst @@ -26,12 +26,6 @@ POST request. The metadata may contain a :ref:`JWT ` ``software_state value. Endpoint can choose if it support ``software_statement``, it is not enabled by default. -.. versionchanged:: v0.15 - - ClientRegistrationEndpoint has a breaking change in v0.15. Method of - ``authenticate_user`` is replaced by ``authenticate_token``, and parameters in - ``save_client`` is also changed. - Before register the endpoint, developers MUST implement the missing methods:: from authlib.oauth2.rfc7591 import ClientRegistrationEndpoint @@ -55,6 +49,12 @@ Before register the endpoint, developers MUST implement the missing methods:: client.save() return client + def generate_client_registration_info(self, client, request): + return { + 'registration_client_uri': url_for("client_management_endpoint", client_id=client.client_id), + 'registration_access_token': create_management_access_token_for_client(client), + } + If developers want to support ``software_statement``, additional methods should be implemented:: diff --git a/docs/oauth2/specs/rfc7592.rst b/docs/oauth2/specs/rfc7592.rst new file mode 100644 index 000000000..f944a7825 --- /dev/null +++ b/docs/oauth2/specs/rfc7592.rst @@ -0,0 +1,99 @@ +.. _specs/rfc7592: + +RFC7592: OAuth 2.0 Dynamic Client Registration Management Protocol +================================================================== + +This section contains the generic implementation of RFC7592_. OAuth 2.0 Dynamic +Client Registration Management Protocol allows developers edit and delete OAuth +client via API through Authorization Server. This specification is an extension +of :ref:`specs/rfc7591`. + + +.. meta:: + :description: Python API references on RFC7592 OAuth 2.0 Dynamic Client + Registration Management Protocol in Python with Authlib implementation. + +.. module:: authlib.oauth2.rfc7592 + +.. _RFC7592: https://tools.ietf.org/html/rfc7592 + +Client Configuration Endpoint +----------------------------- + +Before register the endpoint, developers MUST implement the missing methods:: + + from authlib.oauth2.rfc7592 import ClientConfigurationEndpoint + + + class MyClientConfigurationEndpoint(ClientConfigurationEndpoint): + def authenticate_token(self, request): + # this method is used to authenticate the registration access + # token returned by the RFC7591 registration endpoint + auth_header = request.headers.get('Authorization') + bearer_token = auth_header.split()[1] + token = Token.get(bearer_token) + return token + + def authenticate_client(self, request): + client_id = request.payload.data.get('client_id') + return Client.get(client_id=client_id) + + def revoke_access_token(self, token, request): + token.revoked = True + token.save() + + def check_permission(self, client, request): + return client.editable + + def delete_client(self, client, request): + client.delete() + + def save_client(self, client_info, client_metadata, request): + client = OAuthClient( + user_id=request.credential.user_id, + client_id=client_info['client_id'], + client_secret=client_info['client_secret'], + **client_metadata, + ) + client.save() + return client + + def generate_client_registration_info(self, client, request): + access_token = request.headers['Authorization'].split(' ')[1] + return { + 'registration_client_uri': request.uri, + 'registration_access_token': access_token, + } + + def get_server_metadata(self): + return { + 'issuer': ..., + 'authorization_endpoint': ..., + 'token_endpoint': ..., + 'jwks_uri': ..., + 'registration_endpoint': ..., + 'scopes_supported': ..., + 'response_types_supported': ..., + 'response_modes_supported': ..., + 'grant_types_supported': ..., + 'token_endpoint_auth_methods_supported': ..., + 'token_endpoint_auth_signing_alg_values_supported': ..., + 'service_documentation': ..., + 'ui_locales_supported': ..., + 'op_policy_uri': ..., + 'op_tos_uri': ..., + 'revocation_endpoint': ..., + 'revocation_endpoint_auth_methods_supported': ..., + 'revocation_endpoint_auth_signing_alg_values_supported': ..., + 'introspection_endpoint': ..., + 'introspection_endpoint_auth_methods_supported': ..., + 'introspection_endpoint_auth_signing_alg_values_supported': ..., + 'code_challenge_methods_supported': ..., + } + +API Reference +------------- + +.. autoclass:: ClientConfigurationEndpoint + :member-order: bysource + :members: diff --git a/docs/specs/rfc7636.rst b/docs/oauth2/specs/rfc7636.rst similarity index 88% rename from docs/specs/rfc7636.rst rename to docs/oauth2/specs/rfc7636.rst index 6a69704f6..cc00030d7 100644 --- a/docs/specs/rfc7636.rst +++ b/docs/oauth2/specs/rfc7636.rst @@ -43,13 +43,13 @@ method:: def save_authorization_code(self, code, request): # NOTICE BELOW - code_challenge = request.data.get('code_challenge') - code_challenge_method = request.data.get('code_challenge_method') + code_challenge = request.payload.data.get('code_challenge') + code_challenge_method = request.payload.data.get('code_challenge_method') auth_code = AuthorizationCode( code=code, client_id=request.client.client_id, redirect_uri=request.redirect_uri, - scope=request.scope, + scope=request.payload.scope, user_id=request.user.id, code_challenge=code_challenge, code_challenge_method=code_challenge_method, @@ -63,7 +63,7 @@ Now you can register your ``AuthorizationCodeGrant`` with the extension:: from authlib.oauth2.rfc7636 import CodeChallenge server.register_grant(MyAuthorizationCodeGrant, [CodeChallenge(required=True)]) -If ``required=True``, code challenge is required for authorization code flow. +If ``required=True``, code challenge is required for authorization code flow from public clients. If ``required=False``, it is optional, it will only valid the code challenge when clients send these parameters. @@ -83,11 +83,6 @@ consider that we already have a ``session``:: >>> authorization_response = 'https://example.com/auth?code=42..e9&state=d..t' >>> token = session.fetch_token(token_endpoint, authorization_response=authorization_response, code_verifier=code_verifier) -The authorization flow is the same as in :ref:`oauth_2_session`, what you need -to do is: - -1. adding ``code_challenge`` and ``code_challenge_method`` in ``.create_authorization_url``. -2. adding ``code_verifier`` in ``.fetch_token``. API Reference ------------- diff --git a/docs/specs/rfc7662.rst b/docs/oauth2/specs/rfc7662.rst similarity index 62% rename from docs/specs/rfc7662.rst rename to docs/oauth2/specs/rfc7662.rst index 05bcd32f4..e3877fa6d 100644 --- a/docs/specs/rfc7662.rst +++ b/docs/oauth2/specs/rfc7662.rst @@ -20,6 +20,8 @@ a token. Register Introspection Endpoint ------------------------------- +.. versionchanged:: v1.0 + Authlib is designed to be very extendable, with the method of ``.register_endpoint`` on ``AuthorizationServer``, it is easy to add the introspection endpoint to the authorization server. It works on both @@ -29,7 +31,7 @@ we need to implement the missing methods:: from authlib.oauth2.rfc7662 import IntrospectionEndpoint class MyIntrospectionEndpoint(IntrospectionEndpoint): - def query_token(self, token, token_type_hint, client): + def query_token(self, token, token_type_hint): if token_type_hint == 'access_token': tok = Token.query.filter_by(access_token=token).first() elif token_type_hint == 'refresh_token': @@ -39,11 +41,7 @@ we need to implement the missing methods:: tok = Token.query.filter_by(access_token=token).first() if not tok: tok = Token.query.filter_by(refresh_token=token).first() - if tok: - if tok.client_id == client.client_id: - return tok - if has_introspect_permission(client): - return tok + return tok def introspect_token(self, token): return { @@ -59,6 +57,10 @@ we need to implement the missing methods:: 'iat': token.issued_at, } + def check_permission(self, token, client, request): + # for example, we only allow internal client to access introspection endpoint + return client.client_type == 'internal' + # register it to authorization server server.register_endpoint(MyIntrospectionEndpoint) @@ -69,6 +71,37 @@ After the registration, we can create a response with:: return server.create_endpoint_response(MyIntrospectionEndpoint.ENDPOINT_NAME) +.. _require_oauth_introspection: + +Use Introspection in Resource Server +------------------------------------ + +.. versionadded:: v1.0 + +When resource server has no access to token database, it can use introspection +endpoint to validate the given token. Here is how:: + + import requests + from authlib.oauth2.rfc7662 import IntrospectTokenValidator + from your_project import secrets + + class MyIntrospectTokenValidator(IntrospectTokenValidator): + def introspect_token(self, token_string): + url = 'https://example.com/oauth/introspect' + data = {'token': token_string, 'token_type_hint': 'access_token'} + auth = (secrets.internal_client_id, secrets.internal_client_secret) + resp = requests.post(url, data=data, auth=auth) + resp.raise_for_status() + return resp.json() + +We can then register this token validator in to resource protector:: + + require_oauth = ResourceProtector() + require_oauth.register_token_validator(MyIntrospectTokenValidator()) + +Please note, when using ``IntrospectTokenValidator``, the ``current_token`` will be +a dict. + API Reference ------------- @@ -76,3 +109,6 @@ API Reference :member-order: bysource :members: :inherited-members: + +.. autoclass:: IntrospectTokenValidator + :members: diff --git a/docs/oauth2/specs/rfc8414.rst b/docs/oauth2/specs/rfc8414.rst new file mode 100644 index 000000000..6e0be71df --- /dev/null +++ b/docs/oauth2/specs/rfc8414.rst @@ -0,0 +1,14 @@ +.. _specs/rfc8414: + +RFC8414: OAuth 2.0 Authorization Server Metadata +================================================ + +.. module:: authlib.oauth2.rfc8414 + + +API Reference +------------- + +.. autoclass:: AuthorizationServerMetadata + :member-order: bysource + :members: diff --git a/docs/specs/rfc8628.rst b/docs/oauth2/specs/rfc8628.rst similarity index 100% rename from docs/specs/rfc8628.rst rename to docs/oauth2/specs/rfc8628.rst diff --git a/docs/oauth2/specs/rfc9068.rst b/docs/oauth2/specs/rfc9068.rst new file mode 100644 index 000000000..466c7ff5d --- /dev/null +++ b/docs/oauth2/specs/rfc9068.rst @@ -0,0 +1,66 @@ +.. _specs/rfc9068: + +RFC9068: JSON Web Token (JWT) Profile for OAuth 2.0 Access Tokens +================================================================= + +This section contains the generic implementation of RFC9068_. +JSON Web Token (JWT) Profile for OAuth 2.0 Access Tokens allows +developers to generate JWT access tokens. + +Using JWT instead of plain text for access tokens result in different +possibilities: + +- User information can be filled in the JWT claims, similar to the + :ref:`specs/oidc` ``id_token``, possibly making the economy of + requests to the ``userinfo_endpoint``. +- Resource servers do not *need* to reach the authorization server + :ref:`specs/rfc7662` endpoint to verify each incoming tokens, as + the JWT signature is a proof of its validity. This brings the economy + of one network request at each resource access. +- Consequently, the authorization server do not need to store access + tokens in a database. If a resource server does not implement this + spec and still need to reach the authorization server introspection + endpoint to check the token validation, then the authorization server + can simply validate the JWT without requesting its database. +- If the authorization server do not store access tokens in a database, + it won't have the possibility to revoke the tokens. The produced access + tokens will be valid until the timestamp defined in its ``exp`` claim + is reached. + +This specification is just about **access** tokens. Other kinds of tokens +like refresh tokens are not covered. + +RFC9068_ define a few optional JWT claims inspired from RFC7643_ that can +can be used to determine if the token bearer is authorized to access a +resource: ``groups``, ``roles`` and ``entitlements``. + +This module brings tools to: + +- generate JWT access tokens with :class:`~authlib.oauth2.rfc9068.JWTBearerTokenGenerator` +- protected resources endpoints and validate JWT access tokens with :class:`~authlib.oauth2.rfc9068.JWTBearerTokenValidator` +- introspect JWT access tokens with :class:`~authlib.oauth2.rfc9068.JWTIntrospectionEndpoint` +- deny JWT access tokens revokation attempts with :class:`~authlib.oauth2.rfc9068.JWTRevocationEndpoint` + +.. _RFC9068: https://www.rfc-editor.org/rfc/rfc9068.html +.. _RFC7643: https://tools.ietf.org/html/rfc7643 + +API Reference +------------- + +.. module:: authlib.oauth2.rfc9068 + +.. autoclass:: JWTBearerTokenGenerator + :member-order: bysource + :members: + +.. autoclass:: JWTBearerTokenValidator + :member-order: bysource + :members: + +.. autoclass:: JWTIntrospectionEndpoint + :member-order: bysource + :members: + +.. autoclass:: JWTRevocationEndpoint + :member-order: bysource + :members: diff --git a/docs/oauth2/specs/rfc9101.rst b/docs/oauth2/specs/rfc9101.rst new file mode 100644 index 000000000..a6db00b55 --- /dev/null +++ b/docs/oauth2/specs/rfc9101.rst @@ -0,0 +1,37 @@ +.. _specs/rfc9101: + +RFC9101: The OAuth 2.0 Authorization Framework: JWT-Secured Authorization Request (JAR) +======================================================================================= + +This section contains the generic implementation of :rfc:`RFC9101 <9101>`. + +This specification describe how to pass the authorization request payload +in a JWT (called *request object*) instead of directly instead of GET or POST params. + +The request object can either be passed directly in a ``request`` parameter, +or be hosted by the client and be passed by reference with a ``request_uri`` +parameter. + +This usage is more secure than passing the request payload directly in the request, +read the RFC to know all the details. + +Request objects are optional, unless it is enforced by clients with the +``require_signed_request_object`` client metadata, or server-wide with the +``require_signed_request_object`` server metadata. + +API Reference +------------- + +.. module:: authlib.oauth2.rfc9101 + +.. autoclass:: JWTAuthenticationRequest + :member-order: bysource + :members: + +.. autoclass:: ClientMetadataClaims + :member-order: bysource + :members: + +.. autoclass:: AuthorizationServerMetadata + :member-order: bysource + :members: diff --git a/docs/oauth2/specs/rfc9207.rst b/docs/oauth2/specs/rfc9207.rst new file mode 100644 index 000000000..ba6796cb3 --- /dev/null +++ b/docs/oauth2/specs/rfc9207.rst @@ -0,0 +1,30 @@ +.. _specs/rfc9207: + +RFC9207: OAuth 2.0 Authorization Server Issuer Identification +============================================================= + +This section contains the generic implementation of :rfc:`RFC9207 <9207>`. + +In summary, RFC9207 advise to return an ``iss`` parameter in authorization code responses. +This can simply be done by implementing the :meth:`~authlib.oauth2.rfc9207.parameter.IssuerParameter.get_issuer` method in the :class:`~authlib.oauth2.rfc9207.parameter.IssuerParameter` class, +and pass it as a :class:`~authlib.oauth2.rfc6749.grants.AuthorizationCodeGrant` extension:: + + from authlib.oauth2 import rfc9207 + + class IssuerParameter(rfc9207.IssuerParameter): + def get_issuer(self) -> str: + return "https://auth.example.org" + + ... + + authorization_server.register_extension(IssuerParameter()) + +API Reference +------------- + +.. module:: authlib.oauth2.rfc9207 + +.. autoclass:: IssuerParameter + :member-order: bysource + :members: + diff --git a/docs/oauth2/specs/rpinitiated.rst b/docs/oauth2/specs/rpinitiated.rst new file mode 100644 index 000000000..6adce1135 --- /dev/null +++ b/docs/oauth2/specs/rpinitiated.rst @@ -0,0 +1,174 @@ +.. _specs/rpinitiated: + +OpenID Connect RP-Initiated Logout 1.0 +====================================== + +.. meta:: + :description: Python API references on OpenID Connect RP-Initiated Logout 1.0 + EndSessionEndpoint with Authlib implementation. + +.. module:: authlib.oidc.rpinitiated + +This section contains the generic implementation of `OpenID Connect RP-Initiated +Logout 1.0`_. This specification enables Relying Parties (RPs) to request that +an OpenID Provider (OP) log out the End-User. + +.. _OpenID Connect RP-Initiated Logout 1.0: https://openid.net/specs/openid-connect-rpinitiated-1_0.html + +End Session Endpoint +-------------------- + +To add RP-Initiated Logout support, create a subclass of :class:`EndSessionEndpoint` +and implement the required methods:: + + from authlib.oidc.rpinitiated import EndSessionEndpoint + + class MyEndSessionEndpoint(EndSessionEndpoint): + def get_server_jwks(self): + return load_jwks() + + def end_session(self, end_session_request): + # Terminate user session + session.clear() + + server.register_endpoint(MyEndSessionEndpoint) + +Then create a logout route. You have two options: + +**Non-interactive mode** (simple, no confirmation page):: + + @app.route('/logout', methods=['GET', 'POST']) + def logout(): + return ( + server.create_endpoint_response("end_session") + or render_template('logged_out.html') + ) + +**Interactive mode** (with confirmation page):: + + @app.route('/logout', methods=['GET', 'POST']) + def logout(): + try: + req = server.validate_endpoint_request("end_session") + except OAuth2Error as error: + return server.handle_error_response(None, error) + + # Show confirmation page on GET when no id_token_hint was provided + # User confirms by submitting the form (POST) + if req.needs_confirmation and request.method == 'GET': + return render_template('confirm_logout.html', client=req.client) + + return ( + server.create_endpoint_response("end_session", req) + or render_template('logged_out.html') + ) + +The ``create_endpoint_response`` method returns ``None`` when there is no +``post_logout_redirect_uri``, allowing you to provide your own response page. + +Request Parameters +~~~~~~~~~~~~~~~~~~ + +The endpoint accepts the following parameters (via GET or POST): + +- **id_token_hint** (Recommended): A previously issued ID Token passed as a hint + about the End-User's authenticated session. +- **logout_hint** (Optional): A hint to the OP about the End-User that is logging out. +- **client_id** (Optional): The OAuth 2.0 Client Identifier. When both ``client_id`` + and ``id_token_hint`` are present, the OP verifies that the Client Identifier + matches the ``aud`` claim in the ID Token. +- **post_logout_redirect_uri** (Optional): URI to which the End-User's User Agent + is redirected after logout. Must exactly match a pre-registered value. +- **state** (Optional): Opaque value used by the RP to maintain state between the + logout request and the callback. +- **ui_locales** (Optional): End-User's preferred languages for the user interface. + +Confirmation Flow +~~~~~~~~~~~~~~~~~ + +Per the specification, logout requests without a valid ``id_token_hint`` are a +potential means of denial of service. The :attr:`EndSessionRequest.needs_confirmation` +property indicates when user confirmation is recommended. + +You control the confirmation page rendering - simply check ``needs_confirmation`` +and render your own template as shown in the interactive mode example above. + +Post-Logout Redirection +~~~~~~~~~~~~~~~~~~~~~~~ + +Post-logout redirection only happens when: + +1. A ``post_logout_redirect_uri`` is provided +2. The client is resolved (via ``id_token_hint`` or ``client_id``) +3. The URI is registered in the client's ``post_logout_redirect_uris`` + +If all conditions are met, ``EndSessionRequest.redirect_uri`` contains the +validated URI (with ``state`` appended if provided). + +If conditions are not met, ``create_endpoint_response`` returns ``None`` and +you should provide a default logout page:: + + server.create_endpoint_response("end_session", req) or render_template('logged_out.html') + +Session Validation +~~~~~~~~~~~~~~~~~~ + +When an ``id_token_hint`` is provided, the ``id_token_claims`` attribute of +:class:`EndSessionRequest` contains all claims from the ID Token, including +``sid`` (session ID) if present. + +Per the specification, you SHOULD verify that the ``sid`` matches the current +session to detect potentially suspect logout requests:: + + def end_session(self, end_session_request): + if end_session_request.id_token_claims: + sid = end_session_request.id_token_claims.get("sid") + if sid and sid != get_current_session_id(): + # Treat as suspect - may require additional confirmation + pass + session.clear() + +Client Registration +------------------- + +Relying Parties can register their ``post_logout_redirect_uris`` through +:ref:`RFC7591: OAuth 2.0 Dynamic Client Registration Protocol `. + +To support RP-Initiated Logout client metadata, add the claims class to your +registration and configuration endpoints:: + + from authlib import oidc + from authlib.oauth2 import rfc7591 + + authorization_server.register_endpoint( + ClientRegistrationEndpoint( + claims_classes=[ + rfc7591.ClientMetadataClaims, + oidc.registration.ClientMetadataClaims, + oidc.rpinitiated.ClientMetadataClaims, + ] + ) + ) + +The ``post_logout_redirect_uris`` parameter is an array of URLs to which the +End-User's User Agent may be redirected after logout. These URLs SHOULD use +the ``https`` scheme. + +API Reference +------------- + +.. autoclass:: EndSessionEndpoint + :member-order: bysource + :members: + +.. autoclass:: EndSessionRequest + :member-order: bysource + :members: + +.. autoclass:: ClientMetadataClaims + :member-order: bysource + :members: + +.. autoclass:: OpenIDProviderMetadata + :member-order: bysource + :members: diff --git a/docs/specs/index.rst b/docs/specs/index.rst deleted file mode 100644 index 87d8943d5..000000000 --- a/docs/specs/index.rst +++ /dev/null @@ -1,28 +0,0 @@ -Specifications -============== - -Guide on specifications. You don't have to read this section if you are -just using Authlib. But it would be good for you to understand how Authlib -works. - -.. toctree:: - :maxdepth: 2 - - rfc5849 - rfc6749 - rfc6750 - rfc7009 - rfc7515 - rfc7516 - rfc7517 - rfc7518 - rfc7519 - rfc7523 - rfc7591 - rfc7636 - rfc7638 - rfc7662 - rfc8037 - rfc8414 - rfc8628 - oidc diff --git a/docs/specs/oidc.rst b/docs/specs/oidc.rst deleted file mode 100644 index d767dc609..000000000 --- a/docs/specs/oidc.rst +++ /dev/null @@ -1,55 +0,0 @@ -.. _specs/oidc: - -OpenID Connect 1.0 -================== - -.. meta:: - :description: General implementation of OpenID Connect 1.0 in Python. - Learn how to create a OpenID Connect provider in Python. - -This part of the documentation covers the specification of OpenID Connect. Learn -how to use it in :ref:`flask_oidc_server` and :ref:`django_oidc_server`. - -OpenID Grants -------------- - -.. module:: authlib.oidc.core.grants - -.. autoclass:: OpenIDCode - :show-inheritance: - :members: - -.. autoclass:: OpenIDImplicitGrant - :show-inheritance: - :members: - -.. autoclass:: OpenIDHybridGrant - :show-inheritance: - :members: - -OpenID Claims -------------- - -.. module:: authlib.oidc.core - -.. autoclass:: IDToken - :show-inheritance: - :members: - - -.. autoclass:: CodeIDToken - :show-inheritance: - :members: - - -.. autoclass:: ImplicitIDToken - :show-inheritance: - :members: - - -.. autoclass:: HybridIDToken - :show-inheritance: - :members: - -.. autoclass:: UserInfo - :members: diff --git a/docs/specs/rfc8414.rst b/docs/specs/rfc8414.rst deleted file mode 100644 index 7455816ba..000000000 --- a/docs/specs/rfc8414.rst +++ /dev/null @@ -1,32 +0,0 @@ -.. _specs/rfc8414: - -RFC8414: OAuth 2.0 Authorization Server Metadata -================================================ - -.. module:: authlib.oauth2.rfc8414 - -:class:`AuthorizationServerMetadata` is enabled by default in -framework integrations: - -1. :ref:`flask_oauth2_server` -2. :ref:`django_oauth2_server` - -Configuration -------------- - -In :ref:`flask_oauth2_server`, config with:: - - OAUTH2_METADATA_FILE = '/www/.well-known/oauth-authorization-server' - -In :ref:`django_oauth2_server`, add into settings:: - - AUTHLIB_OAUTH2_PROVIDER = { - 'metadata_file': '/www/.well-known/oauth-authorization-server' - } - -API Reference -------------- - -.. autoclass:: AuthorizationServerMetadata - :member-order: bysource - :members: diff --git a/docs/upgrades/changelog.rst b/docs/upgrades/changelog.rst new file mode 100644 index 000000000..6ed36a7c5 --- /dev/null +++ b/docs/upgrades/changelog.rst @@ -0,0 +1,358 @@ +Changelog +========= + +.. meta:: + :description: The full list of changes between each Authlib release. + +Here you can see the full list of changes between each Authlib release. + +Version 1.7.0 +------------- + +**Released on Apr 18, 2026** + +- Add support for `OpenID Connect RP-Initiated Logout 1.0 + `_. + See :ref:`specs/rpinitiated` for details. :issue:`500` +- Per RFC 6749 Section 3.3, the ``scope`` parameter is now optional at both + authorization and token endpoints. ``client.get_allowed_scope()`` is called + to determine the default scope when omitted. :issue:`845` +- Stop support for Python 3.9, start support Python 3.14. :pr:`850` +- Allow ``AuthorizationServerMetadata.validate()`` to compose with RFC extension classes. +- Fix ``expires_at=0`` being incorrectly treated as ``None``. :issue:`530` +- Allow ``ResourceProtector`` decorator to be used without parentheses. :issue:`604` +- Implement RFC9700 PKCE downgrade countermeasure. +- Set ``User-Agent`` header when fetching server metadata and JWKs. :issue:`704` +- RFC7523 accepts the issuer URL as a valid audience. :issue:`730` +- Fix ``InvalidTokenError`` extra attributes being wrapped instead of passed as + individual key=value pairs in the ``WWW-Authenticate`` header. :pr:`872` + +Upgrade Guide: :ref:`joserfc_upgrade`. + +Version 1.6.11 +-------------- + +**Released on Apr 16, 2026** + +- Fix CSRF vulnerability in the Starlette OAuth client when a ``cache`` is + configured. + +Version 1.6.10 +-------------- + +**Released on Apr 13, 2026** + +- Fix redirecting to unvalidated ``redirect_uri`` on ``UnsupportedResponseTypeError``. + +Version 1.6.9 +------------- + +**Released on Mar 2, 2026** + +- Not using header's ``jwk`` automatically. +- Add ``ES256K`` into default jwt algorithms. +- Remove deprecated algorithm from default registry. +- Generate random ``cek`` when ``cek`` length doesn't match. + +Version 1.6.8 +------------- + +**Released on Feb 17, 2026** + +- Add ``EdDSA`` to default ``jwt`` instance. + +Version 1.6.7 +------------- + +**Released on Feb 6, 2026** + +- Set supported algorithms for the default ``jwt`` instance. + +Version 1.6.6 +------------- + +**Released on Jan 9, 2026** + +- ``get_jwt_config`` takes a ``client`` parameter, :pr:`844`. +- Fix incorrect signature when ``Content-Type`` is x-www-form-urlencoded for OAuth 1.0 Client, :pr:`778`. +- Use ``expires_in`` in ``OAuth2Token`` when ``expires_at`` is unparsable, :pr:`842`. +- Always track ``state`` in session for OAuth client integrations. + +Version 1.6.5 +------------- + +**Released on Oct 2, 2025** + +- RFC7591 ``generate_client_info`` and ``generate_client_secret`` take a ``request`` parameter. +- Add size limitation when decode JWS/JWE to prevent DoS. +- Add size limitation for ``DEF`` JWE zip algorithm. + +Version 1.6.4 +------------- + +**Released on Sep 17, 2025** + +- Fix ``InsecureTransportError`` error raising. :issue:`795` +- Fix ``response_mode=form_post`` with Starlette client. :issue:`793` +- Validate ``crit`` header value, reject unprotected header in ``crit`` header. + +Version 1.6.3 +------------- + +**Released on Aug 26, 2025** + +- OIDC ``id_token`` are signed according to ``id_token_signed_response_alg`` + client metadata. :issue:`755` + +Version 1.6.2 +------------- + +**Released on Aug 23, 2025** + +- Temporarily restore ``OAuth2Request`` ``body`` parameter. :issue:`781` :pr:`791` +- Allow ``127.0.0.1`` in insecure transport mode. :pr:`788` +- Raise ``MissingCodeException`` when the ``code`` parameter is missing. :issue:`793` :pr:`794` +- Fix ``id_token`` generation with `EdDSA` algs. :issue:`799` :pr:`800` + +Version 1.6.1 +------------- + +**Released on Jul 20, 2025** + +- Filter key set with additional "alg" and "use" parameters. +- Restore and deprecate ``OAuth2Request`` ``body`` parameter. :issue:`781` + +Version 1.6.0 +------------- + +**Released on May 22, 2025** + +- Fix issue when :rfc:`RFC9207 <9207>` is enabled and the authorization endpoint response is not a redirection. :pr:`733` +- Fix missing ``state`` parameter in authorization error responses. :issue:`525` +- Support for the ``none`` JWS algorithm. +- Fix ``response_types`` strict order during dynamic client registration. :issue:`760` +- Implement :rfc:`RFC9101 The OAuth 2.0 Authorization Framework: JWT-Secured Authorization Request (JAR) <9101>`. :issue:`723` +- OIDC :class:`UserInfo endpoint ` support. :issue:`459` + +**Breaking changes**: + +- Support for ``acr`` and ``amr`` claims in ``id_token``. :issue:`734` + The ``OAuth2AuthorizationCodeMixin`` must have a migration to support the new fields. + +Version 1.5.2 +------------- + +**Released on Apr 1, 2025** + +- Forbid fragments in ``redirect_uris``. :issue:`714` +- Fix invalid characters in ``error_description``. :issue:`720` +- Add ``claims_cls`` parameter for client's ``parse_id_token`` method. :issue:`725` + +Version 1.5.1 +------------- + +**Released on Feb 28, 2025** + +- Fix RFC9207 ``iss`` parameter. :pr:`715` + +Version 1.5.0 +------------- + +**Released on Feb 25, 2025** + +- Fix token introspection auth method for clients. :pr:`662` +- Optional ``typ`` claim in JWT tokens. :pr:`696` +- JWT validation leeway. :pr:`689` +- Implement server-side :rfc:`RFC9207 <9207>`. :issue:`700` :pr:`701` +- ``generate_id_token`` can take a ``kid`` parameter. :pr:`702` +- More detailed ``InvalidClientError``. :pr:`706` +- OpenID Connect Dynamic Client Registration implementation. :pr:`707` + +Version 1.4.1 +------------- + +**Released on Jan 28, 2025** + +- Improve garbage collection on OAuth clients. :issue:`698` +- Fix client parameters for httpx. :issue:`694` + +Version 1.4.0 +------------- + +**Released on Dec 20, 2024** + +- Fix ``id_token`` decoding when kid is null. :pr:`659` +- Support for Python 3.13. :pr:`682` +- Force login if the ``prompt`` parameter value is ``login``. :pr:`637` +- Support for httpx 0.28, :pr:`695` + +**Breaking changes**: + +- Stop support for Python 3.8. :pr:`682` + +Version 1.3.2 +------------- + +**Released on Aug 30 2024** + +- Prevent ever-growing session size for OAuth clients. +- Revert ``quote`` client id and secret. +- ``unquote`` basic auth header for authorization server. + +Version 1.3.1 +------------- + +**Released on June 4, 2024** + +- Prevent ``OctKey`` to import ssh and PEM strings. + + +Version 1.3.0 +------------- + +**Released on Dec 17, 2023** + +- Restore ``AuthorizationServer.create_authorization_response`` behavior, via :PR:`558` +- Include ``leeway`` in ``validate_iat()`` for JWT, via :PR:`565` +- Fix ``encode_client_secret_basic``, via :PR:`594` +- Use single key in JWK if JWS does not specify ``kid``, via :PR:`596` +- Fix error when RFC9068 JWS has no scope field, via :PR:`598` +- Get werkzeug version using importlib, via :PR:`591` + +**New features**: + +- RFC9068 implementation, via :PR:`586`, by @azmeuk. + +**Breaking changes**: + +- End support for python 3.7 + +Version 1.2.1 +------------- + +**Released on Jun 25, 2023** + +- Apply headers in ``ClientSecretJWT.sign`` method, via :PR:`552` +- Allow falsy but non-None grant uri params, via :PR:`544` +- Fixed ``authorize_redirect`` for Starlette v0.26.0, via :PR:`533` +- Removed ``has_client_secret`` method and documentation, via :PR:`513` +- Removed ``request_invalid`` and ``token_revoked`` remaining occurrences + and documentation. :PR:`514` +- Fixed RFC7591 ``grant_types`` and ``response_types`` default values, via :PR:`509`. +- Add support for python 3.12, via :PR:`590`. + +Version 1.2.0 +------------- + +**Released on Dec 6, 2022** + +- Not passing ``request.body`` to ``ResourceProtector``, via :issue:`485`. +- Use ``flask.g`` instead of ``_app_ctx_stack``, via :issue:`482`. +- Add ``headers`` parameter back to ``ClientSecretJWT``, via :issue:`457`. +- Always passing ``realm`` parameter in OAuth 1 clients, via :issue:`339`. +- Implemented RFC7592 Dynamic Client Registration Management Protocol, via :PR:`505`. +- Add ``default_timeout`` for requests ``OAuth2Session`` and ``AssertionSession``. +- Deprecate ``jwk.loads`` and ``jwk.dumps`` + +Version 1.1.0 +------------- + +**Released on Sep 13, 2022** + +This release contains breaking changes and security fixes. + +- Allow to pass ``claims_options`` to Framework OpenID Connect clients, via :PR:`446`. +- Fix ``.stream`` with context for HTTPX OAuth clients, via :PR:`465`. +- Fix Starlette OAuth client for cache store, via :PR:`478`. + +**Breaking changes**: + +- Raise ``InvalidGrantError`` for invalid code, redirect_uri and no user errors in OAuth + 2.0 server. +- The default ``authlib.jose.jwt`` would only work with JSON Web Signature algorithms, if + you would like to use JWT with JWE algorithms, please pass the algorithms parameter:: + + jwt = JsonWebToken(['A128KW', 'A128GCM', 'DEF']) + +**Security fixes**: CVE-2022-39175 and CVE-2022-39174, both related to JOSE. + +Version 1.0.1 +------------- + +**Released on Apr 6, 2022** + +- Fix authenticate_none method, via :issue:`438`. +- Allow to pass in alternative signing algorithm to RFC7523 authentication methods via :PR:`447`. +- Fix ``missing_token`` for Flask OAuth client, via :issue:`448`. +- Allow ``openid`` in any place of the scope, via :issue:`449`. +- Security fix for validating essential value on blank value in JWT, via :issue:`445`. + + +Version 1.0.0 +------------- + +**Released on Mar 15, 2022.** + +We have dropped support for Python 2 in this release. We have removed +built-in SQLAlchemy integration. + +**OAuth Client Changes:** + +The whole framework client integrations have been restructured, if you are +using the client properly, e.g. ``oauth.register(...)``, it would work as +before. + +**OAuth Provider Changes:** + +In Flask OAuth 2.0 provider, we have removed the deprecated +``OAUTH2_JWT_XXX`` configuration, instead, developers should define +`.get_jwt_config` on OpenID extensions and grant types. + +**SQLAlchemy** integrations has been removed from Authlib. Developers +should define the database by themselves. + +**JOSE Changes** + +- ``JWS`` has been renamed to ``JsonWebSignature`` +- ``JWE`` has been renamed to ``JsonWebEncryption`` +- ``JWK`` has been renamed to ``JsonWebKey`` +- ``JWT`` has been renamed to ``JsonWebToken`` + +The "Key" model has been re-designed, checkout the :ref:`jwk_guide` for updates. + +Added ``ES256K`` algorithm for JWS and JWT. + +**Breaking Changes**: find how to solve the deprecate issues via https://git.io/JkY4f + + +Old Versions +------------ + +Find old changelog at https://github.com/authlib/authlib/releases + +- Version 0.15.5: Released on Oct 18, 2021 +- Version 0.15.4: Released on Jul 17, 2021 +- Version 0.15.3: Released on Jan 15, 2021 +- Version 0.15.2: Released on Oct 18, 2020 +- Version 0.15.1: Released on Oct 14, 2020 +- Version 0.15.0: Released on Oct 10, 2020 +- Version 0.14.3: Released on May 18, 2020 +- Version 0.14.2: Released on May 6, 2020 +- Version 0.14.1: Released on Feb 12, 2020 +- Version 0.14.0: Released on Feb 11, 2020 +- Version 0.13.0: Released on Nov 11, 2019 +- Version 0.12.0: Released on Sep 3, 2019 +- Version 0.11.0: Released on Apr 6, 2019 +- Version 0.10.0: Released on Oct 12, 2018 +- Version 0.9.0: Released on Aug 12, 2018 +- Version 0.8.0: Released on Jun 17, 2018 +- Version 0.7.0: Released on Apr 28, 2018 +- Version 0.6.0: Released on Mar 20, 2018 +- Version 0.5.1: Released on Feb 11, 2018 +- Version 0.5.0: Released on Feb 11, 2018 +- Version 0.4.1: Released on Feb 2, 2018 +- Version 0.4.0: Released on Jan 31, 2018 +- Version 0.3.0: Released on Dec 24, 2017 +- Version 0.2.1: Released on Dec 6, 2017 +- Version 0.2.0: Released on Nov 25, 2017 +- Version 0.1.0: Released on Nov 18, 2017 diff --git a/docs/upgrades/index.rst b/docs/upgrades/index.rst new file mode 100644 index 000000000..e5cad35fc --- /dev/null +++ b/docs/upgrades/index.rst @@ -0,0 +1,8 @@ +Releases +======== + +.. toctree:: + :maxdepth: 2 + + changelog + jose diff --git a/docs/upgrades/jose.rst b/docs/upgrades/jose.rst new file mode 100644 index 000000000..ef2c6e07b --- /dev/null +++ b/docs/upgrades/jose.rst @@ -0,0 +1,114 @@ +.. _joserfc_upgrade: + +joserfc migration +================= + +joserfc_ is derived from Authlib and provides a cleaner design along with +first-class type hints. We strongly recommend using ``joserfc`` instead of +the ``authlib.jose`` module. + +Starting with **Authlib 1.7.0**, the ``authlib.jose`` module is deprecated and +will emit deprecation warnings. A comprehensive +`Migrating from Authlib `_ +guide is available in the joserfc_ documentation to help you transition. + +.. _joserfc: https://jose.authlib.org/en/ + +The following modules are affected by this upgrade: + +- ``authlib.oauth2.rfc7523`` +- ``authlib.oauth2.rfc7591`` +- ``authlib.oauth2.rfc7592`` +- ``authlib.oauth2.rfc9068`` +- ``authlib.oauth2.rfc9101`` +- ``authlib.oidc.core`` + +Breaking Changes +---------------- + +A common breaking change involves the exceptions raised by the affected modules. +Since these modules now use ``joserfc``, all exceptions are ``joserfc``-based. +If your code previously caught exceptions from ``authlib.jose``, you should +update it to catch the corresponding exceptions from ``joserfc`` instead. + +.. code-block:: diff + + -from authlib.jose.errors import JoseError + +from joserfc.errors import JoseError + + try: + do_something() + except JoseError: + pass + +JWTAuthenticationRequest +~~~~~~~~~~~~~~~~~~~~~~~~ + + +Starting with v1.7, ``authlib.oauth2.rfc9101.JWTAuthenticationRequest`` uses +only the recommended JWT algorithms by default. If you need to support additional +algorithms, you can explicitly include them in ``get_server_metadata``: + +.. code-block:: python + + class MyJWTAuthenticationRequest(JWTAuthenticationRequest): + def get_server_metadata(self): + return { + ..., + "request_object_signing_alg_values_supported": ["RS256", ...], + } + + +UserInfoEndpoint +~~~~~~~~~~~~~~~~ + +The signing algorithms supported by ``authlib.oidc.core.UserInfoEndpoint`` are +limited to the recommended JWT algorithms. If you need to support additional +algorithms, you can explicitly include them in ``get_supported_algorithms``: + +.. code-block:: python + + class MyUserInfoEndpoint(UserInfoEndpoint): + def get_supported_algorithms(self): + return ["RS512"] + +Deprecating Messages +-------------------- + +Most deprecation warnings are triggered by how keys are imported. For security +reasons, joserfc_ requires explicit key types. Instead of passing raw strings or +bytes as keys, you should return ``OctKey``, ``RSAKey``, ``ECKey``, ``OKPKey``, +or ``KeySet`` instances directly. + +``get_jwt_config`` +------------------ + +``get_jwt_config`` is converted into 3 methods: + +1. ``resolve_client_private_key`` +2. ``get_client_claims`` +3. ``get_client_algorithm`` + +.. code-block:: python + + # before 1.7 + class OpenIDCode(grants.OpenIDCode): + def get_jwt_config(self): + return { + 'key': read_private_key_file(key_path), + 'alg': 'RS512', + 'iss': 'https://example.com', + 'exp': 3600 + } + + # authlib>=1.7 + class OpenIDCode(grants.OpenIDCode): + def resolve_client_private_key(self, client): + with open(jwks_file_path) as f: + data = json.load(f) + return KeySet.import_key_set(data) + + def get_client_claims(self, client): + return { + 'iss': 'https://example.com', + } diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..e365ae59b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,143 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "Authlib" +description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients." +authors = [{name = "Hsiaoming Yang", email="me@lepture.com"}] +dependencies = [ + "cryptography", + "joserfc>=1.6.0", +] +license = {text = "BSD-3-Clause"} +requires-python = ">=3.10" +dynamic = ["version"] +readme = "README.md" +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + "Topic :: Security", + "Topic :: Security :: Cryptography", + "Topic :: Internet :: WWW/HTTP :: Dynamic Content", + "Topic :: Internet :: WWW/HTTP :: WSGI :: Application", +] + +[project.urls] +Documentation = "https://docs.authlib.org/" +Purchase = "https://authlib.org/plans" +Issues = "https://github.com/authlib/authlib/issues" +Source = "https://github.com/authlib/authlib" +Donate = "https://github.com/sponsors/lepture" +Blog = "https://blog.authlib.org/" + +[dependency-groups] +dev = [ + "coverage", + "cryptography", + "diff-cover>=9.6.0", + "prek>=0.1.3", + "pytest", + "pytest-asyncio", + "pytest-env", + "tox-uv >= 1.16.0", +] + +clients = [ + "anyio", + "cachelib", + "django", + "flask", + "httpx", + "requests", + "starlette[full]", +] + +django = [ + "django", + "pytest-django", +] + +flask = [ + "Flask", + "Flask-SQLAlchemy", +] + +jose = [ + "pycryptodomex>=3.10,<4", +] + +docs = [ + "shibuya", + "sphinx", + "sphinx-design", + "sphinx-copybutton", +] + +[tool.setuptools.dynamic] +version = {attr = "authlib.__version__"} + +[tool.setuptools.packages.find] +where = ["."] +include = ["authlib", "authlib.*"] + +[tool.ruff.lint] +select = [ + "B", # flake8-bugbear + "E", # pycodestyle + "F", # pyflakes + "I", # isort + "UP", # pyupgrade +] +ignore = [ + "E501", # line-too-long + "E722", # bare-except +] + +[tool.ruff.lint.isort] +force-single-line = true + +[tool.ruff.format] +docstring-code-format = true + +[tool.pytest.ini_options] +asyncio_default_fixture_loop_scope = "function" +asyncio_mode = "auto" +norecursedirs = ["authlib", "build", "dist", "docs", "htmlcov"] +pythonpath = ["."] +env = [ + "DJANGO_SETTINGS_MODULE = tests.django_settings", +] + +[tool.coverage.run] +branch = true + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "except ImportError", + "def __repr__", + "raise NotImplementedError", + "raise DeprecationWarning", + "deprecate", + "if TYPE_CHECKING:", +] + +[tool.check-manifest] +ignore = ["tox.ini"] + +[tool.distutils.bdist_wheel] +universal = true diff --git a/requirements-docs.txt b/requirements-docs.txt deleted file mode 100644 index 964d6aef8..000000000 --- a/requirements-docs.txt +++ /dev/null @@ -1,8 +0,0 @@ -cryptography -Flask -Django -SQLAlchemy -requests -httpx -starlette -sphinx-typlog-theme==0.8.0 diff --git a/requirements-test.txt b/requirements-test.txt deleted file mode 100644 index a96de9ff6..000000000 --- a/requirements-test.txt +++ /dev/null @@ -1,5 +0,0 @@ -cryptography -requests -mock -pytest -coverage diff --git a/serve.py b/serve.py new file mode 100644 index 000000000..a96711d26 --- /dev/null +++ b/serve.py @@ -0,0 +1,7 @@ +from livereload import Server +from livereload import shell + +app = Server() +# app.watch("src", shell("make build-docs"), delay=2) +app.watch("docs", shell("make build-docs"), delay=2) +app.serve(root="build/_html") diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 8ecb03444..000000000 --- a/setup.cfg +++ /dev/null @@ -1,33 +0,0 @@ -[bdist_wheel] -universal = 1 - -[metadata] -license_file = LICENSE - -[check-manifest] -ignore = - tox.ini - -[flake8] -exclude = - tests/* -max-line-length = 100 -max-complexity = 10 - -[tool:pytest] -DJANGO_SETTINGS_MODULE = tests.django.settings -python_files = test*.py -python_paths = tests -norecursedirs=authlib build dist docs htmlcov - -[coverage:run] -branch = True - -[coverage:report] -exclude_lines = - pragma: no cover - except ImportError - def __repr__ - raise NotImplementedError - raise DeprecationWarning - deprecate diff --git a/setup.py b/setup.py old mode 100755 new mode 100644 index 6b6bf27df..78457defb --- a/setup.py +++ b/setup.py @@ -1,66 +1,10 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - - -from setuptools import setup, find_packages -from authlib.consts import version, homepage - - -with open('README.rst') as f: - readme = f.read() - - -client_requires = ['requests'] -crypto_requires = ['cryptography'] +from setuptools import setup +# Metadata goes in setup.cfg. These are here for GitHub's dependency graph. setup( - name='Authlib', - version=version, - author='Hsiaoming Yang', - author_email='me@lepture.com', - url=homepage, - packages=find_packages(include=('authlib', 'authlib.*')), - description=( - 'The ultimate Python library in building OAuth and ' - 'OpenID Connect servers.' - ), - zip_safe=False, - include_package_data=True, - platforms='any', - long_description=readme, - license='BSD-3-Clause', - install_requires=crypto_requires, - extras_require={ - 'client': client_requires, - }, - project_urls={ - 'Documentation': 'https://docs.authlib.org/', - 'Commercial License': 'https://authlib.org/plans', - 'Bug Tracker': 'https://github.com/lepture/authlib/issues', - 'Source Code': 'https://github.com/lepture/authlib', - 'Blog': 'https://blog.authlib.org/', - 'Donate': 'https://lepture.com/donate', - }, - classifiers=[ - 'Development Status :: 4 - Beta', - 'Environment :: Console', - 'Environment :: Web Environment', - 'Framework :: Flask', - 'Framework :: Django', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: BSD License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Topic :: Internet :: WWW/HTTP :: Dynamic Content', - 'Topic :: Internet :: WWW/HTTP :: WSGI :: Application', - 'Topic :: Software Development :: Libraries :: Python Modules', - ] + name="Authlib", + install_requires=[ + "cryptography>=3.2", + ], ) diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 000000000..e05d4e462 --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,9 @@ +sonar.projectKey=authlib_authlib +sonar.organization=authlib + +sonar.sources=authlib +sonar.sourceEncoding=UTF-8 +sonar.test.inclusions=tests/**/test_*.py + +sonar.python.version=3.10, 3.11, 3.12, 3.13, 3.14 +sonar.python.coverage.reportPaths=coverage.xml diff --git a/tests/core/test_jose/__init__.py b/tests/clients/__init__.py similarity index 100% rename from tests/core/test_jose/__init__.py rename to tests/clients/__init__.py diff --git a/tests/clients/asgi_helper.py b/tests/clients/asgi_helper.py new file mode 100644 index 000000000..c5275441c --- /dev/null +++ b/tests/clients/asgi_helper.py @@ -0,0 +1,67 @@ +import json + +from starlette.requests import Request as ASGIRequest +from starlette.responses import Response as ASGIResponse + + +class AsyncMockDispatch: + def __init__(self, body=b"", status_code=200, headers=None, assert_func=None): + if headers is None: + headers = {} + if isinstance(body, dict): + body = json.dumps(body).encode() + headers["Content-Type"] = "application/json" + else: + if isinstance(body, str): + body = body.encode() + headers["Content-Type"] = "application/x-www-form-urlencoded" + + self.body = body + self.status_code = status_code + self.headers = headers + self.assert_func = assert_func + + async def __call__(self, scope, receive, send): + request = ASGIRequest(scope, receive=receive) + + if self.assert_func: + await self.assert_func(request) + + response = ASGIResponse( + status_code=self.status_code, + content=self.body, + headers=self.headers, + ) + await response(scope, receive, send) + + +class AsyncPathMapDispatch: + def __init__(self, path_maps, side_effects=None): + self.path_maps = path_maps + self.side_effects = side_effects or dict() + + async def __call__(self, scope, receive, send): + request = ASGIRequest(scope, receive=receive) + + side_effect = self.side_effects.get(request.url.path) + if side_effect is not None: + side_effect(request) + + rv = self.path_maps[request.url.path] + status_code = rv.get("status_code", 200) + body = rv.get("body") + headers = rv.get("headers", {}) + if isinstance(body, dict): + body = json.dumps(body).encode() + headers["Content-Type"] = "application/json" + else: + if isinstance(body, str): + body = body.encode() + headers["Content-Type"] = "application/x-www-form-urlencoded" + + response = ASGIResponse( + status_code=status_code, + content=body, + headers=headers, + ) + await response(scope, receive, send) diff --git a/tests/clients/keys/jwks_private.json b/tests/clients/keys/jwks_private.json new file mode 100644 index 000000000..2b2149f80 --- /dev/null +++ b/tests/clients/keys/jwks_private.json @@ -0,0 +1,6 @@ +{ + "keys": [ + {"kty": "RSA", "kid": "abc", "n": "pF1JaMSN8TEsh4N4O_5SpEAVLivJyLH-Cgl3OQBPGgJkt8cg49oasl-5iJS-VdrILxWM9_JCJyURpUuslX4Eb4eUBtQ0x5BaPa8-S2NLdGTaL7nBOO8o8n0C5FEUU-qlEip79KE8aqOj-OC44VsIquSmOvWIQD26n3fCVlgwoRBD1gzzsDOeaSyzpKrZR851Kh6rEmF2qjJ8jt6EkxMsRNACmBomzgA4M1TTsisSUO87444pe35Z4_n5c735o2fZMrGgMwiJNh7rT8SYxtIkxngioiGnwkxGQxQ4NzPAHg-XSY0J04pNm7KqTkgtxyrqOANJLIjXlR-U9SQ90NjHVQ", "e": "AQAB", "d": "G4E84ppZwm3fLMI0YZ26iJ_sq3BKcRpQD6_r0o8ZrZmO7y4Uc-ywoP7h1lhFzaox66cokuloZpKOdGHIfK-84EkI3WeveWHPqBjmTMlN_ClQVcI48mUbLhD7Zeenhi9y9ipD2fkNWi8OJny8k4GfXrGqm50w8schrsPksnxJjvocGMT6KZNfDURKF2HlM5X1uY8VCofokXOjBEeHIfYM8e7IcmPpyXwXKonDmVVbMbefo-u-TttgeyOYaO6s3flSy6Y0CnpWi43JQ_VEARxQl6Brj1oizr8UnQQ0nNCOWwDNVtOV4eSl7PZoiiT7CxYkYnhJXECMAM5YBpm4Qk9zdQ", "p": "1g4ZGrXOuo75p9_MRIepXGpBWxip4V7B9XmO9WzPCv8nMorJntWBmsYV1I01aITxadHatO4Gl2xLniNkDyrEQzJ7w38RQgsVK-CqbnC0K9N77QPbHeC1YQd9RCNyUohOimKvb7jyv798FBU1GO5QI2eNgfnnfteSVXhD2iOoTOs", "q": "xJJ-8toxJdnLa0uUsAbql6zeNXGbUBMzu3FomKlyuWuq841jS2kIalaO_TRj5hbnE45jmCjeLgTVO6Ach3Wfk4zrqajqfFJ0zUg_Wexp49lC3RWiV4icBb85Q6bzeJD9Dn9vhjpfWVkczf_NeA1fGH_pcgfkT6Dm706GFFttLL8", "dp": "Zfx3l5NR-O8QIhzuHSSp279Afl_E6P0V2phdNa_vAaVKDrmzkHrXcl-4nPnenXrh7vIuiw_xkgnmCWWBUfylYALYlu-e0GGpZ6t2aIJIRa1QmT_CEX0zzhQcae-dk5cgHK0iO0_aUOOyAXuNPeClzAiVknz4ACZDsXdIlNFyaZs", "dq": "Z9DG4xOBKXBhEoWUPXMpqnlN0gPx9tRtWe2HRDkZsfu_CWn-qvEJ1L9qPSfSKs6ls5pb1xyeWseKpjblWlUwtgiS3cOsM4SI03H4o1FMi11PBtxKJNitLgvT_nrJ0z8fpux-xfFGMjXyFImoxmKpepLzg5nPZo6f6HscLNwsSJk", "qi": "Sk20wFvilpRKHq79xxFWiDUPHi0x0pp82dYIEntGQkKUWkbSlhgf3MAi5NEQTDmXdnB-rVeWIvEi-BXfdnNgdn8eC4zSdtF4sIAhYr5VWZo0WVWDhT7u2ccvZBFymiz8lo3gN57wGUCi9pbZqzV1-ZppX6YTNDdDCE0q-KO3Cec"}, + {"kty": "RSA", "kid": "bilbo.baggins@hobbiton.example", "use": "sig", "n": "n4EPtAOCc9AlkeQHPzHStgAbgs7bTZLwUBZdR8_KuKPEHLd4rHVTeT-O-XV2jRojdNhxJWTDvNd7nqQ0VEiZQHz_AJmSCpMaJMRBSFKrKb2wqVwGU_NsYOYL-QtiWN2lbzcEe6XC0dApr5ydQLrHqkHHig3RBordaZ6Aj-oBHqFEHYpPe7Tpe-OfVfHd1E6cS6M1FZcD1NNLYD5lFHpPI9bTwJlsde3uhGqC0ZCuEHg8lhzwOHrtIQbS0FVbb9k3-tVTU4fg_3L_vniUFAKwuCLqKnS2BYwdq_mzSnbLY7h_qixoR7jig3__kRhuaxwUkRz5iaiQkqgc5gHdrNP5zw", "e": "AQAB", "d": "bWUC9B-EFRIo8kpGfh0ZuyGPvMNKvYWNtB_ikiH9k20eT-O1q_I78eiZkpXxXQ0UTEs2LsNRS-8uJbvQ-A1irkwMSMkK1J3XTGgdrhCku9gRldY7sNA_AKZGh-Q661_42rINLRCe8W-nZ34ui_qOfkLnK9QWDDqpaIsA-bMwWWSDFu2MUBYwkHTMEzLYGqOe04noqeq1hExBTHBOBdkMXiuFhUq1BU6l-DqEiWxqg82sXt2h-LMnT3046AOYJoRioz75tSUQfGCshWTBnP5uDjd18kKhyv07lhfSJdrPdM5Plyl21hsFf4L_mHCuoFau7gdsPfHPxxjVOcOpBrQzwQ", "p": "3Slxg_DwTXJcb6095RoXygQCAZ5RnAvZlno1yhHtnUex_fp7AZ_9nRaO7HX_-SFfGQeutao2TDjDAWU4Vupk8rw9JR0AzZ0N2fvuIAmr_WCsmGpeNqQnev1T7IyEsnh8UMt-n5CafhkikzhEsrmndH6LxOrvRJlsPp6Zv8bUq0k", "q": "uKE2dh-cTf6ERF4k4e_jy78GfPYUIaUyoSSJuBzp3Cubk3OCqs6grT8bR_cu0Dm1MZwWmtdqDyI95HrUeq3MP15vMMON8lHTeZu2lmKvwqW7anV5UzhM1iZ7z4yMkuUwFWoBvyY898EXvRD-hdqRxHlSqAZ192zB3pVFJ0s7pFc", "dp": "B8PVvXkvJrj2L-GYQ7v3y9r6Kw5g9SahXBwsWUzp19TVlgI-YV85q1NIb1rxQtD-IsXXR3-TanevuRPRt5OBOdiMGQp8pbt26gljYfKU_E9xn-RULHz0-ed9E9gXLKD4VGngpz-PfQ_q29pk5xWHoJp009Qf1HvChixRX59ehik", "dq": "CLDmDGduhylc9o7r84rEUVn7pzQ6PF83Y-iBZx5NT-TpnOZKF1pErAMVeKzFEl41DlHHqqBLSM0W1sOFbwTxYWZDm6sI6og5iTbwQGIC3gnJKbi_7k_vJgGHwHxgPaX2PnvP-zyEkDERuf-ry4c_Z11Cq9AqC2yeL6kdKT1cYF8", "qi": "3PiqvXQN0zwMeE-sBvZgi289XP9XCQF3VWqPzMKnIgQp7_Tugo6-NZBKCQsMf3HaEGBjTVJs_jcK8-TRXvaKe-7ZMaQj8VfBdYkssbu0NKDDhjJ-GtiseaDVWt7dcH0cfwxgFUHpQh7FoCrjFJ6h6ZEpMF6xmujs4qMpPz8aaI4"} + ] +} diff --git a/tests/clients/keys/jwks_public.json b/tests/clients/keys/jwks_public.json new file mode 100644 index 000000000..e29644a62 --- /dev/null +++ b/tests/clients/keys/jwks_public.json @@ -0,0 +1,6 @@ +{ + "keys": [ + {"kty": "RSA", "kid": "abc", "n": "pF1JaMSN8TEsh4N4O_5SpEAVLivJyLH-Cgl3OQBPGgJkt8cg49oasl-5iJS-VdrILxWM9_JCJyURpUuslX4Eb4eUBtQ0x5BaPa8-S2NLdGTaL7nBOO8o8n0C5FEUU-qlEip79KE8aqOj-OC44VsIquSmOvWIQD26n3fCVlgwoRBD1gzzsDOeaSyzpKrZR851Kh6rEmF2qjJ8jt6EkxMsRNACmBomzgA4M1TTsisSUO87444pe35Z4_n5c735o2fZMrGgMwiJNh7rT8SYxtIkxngioiGnwkxGQxQ4NzPAHg-XSY0J04pNm7KqTkgtxyrqOANJLIjXlR-U9SQ90NjHVQ", "e": "AQAB"}, + {"kty": "RSA", "kid": "bilbo.baggins@hobbiton.example", "use": "sig", "n": "n4EPtAOCc9AlkeQHPzHStgAbgs7bTZLwUBZdR8_KuKPEHLd4rHVTeT-O-XV2jRojdNhxJWTDvNd7nqQ0VEiZQHz_AJmSCpMaJMRBSFKrKb2wqVwGU_NsYOYL-QtiWN2lbzcEe6XC0dApr5ydQLrHqkHHig3RBordaZ6Aj-oBHqFEHYpPe7Tpe-OfVfHd1E6cS6M1FZcD1NNLYD5lFHpPI9bTwJlsde3uhGqC0ZCuEHg8lhzwOHrtIQbS0FVbb9k3-tVTU4fg_3L_vniUFAKwuCLqKnS2BYwdq_mzSnbLY7h_qixoR7jig3__kRhuaxwUkRz5iaiQkqgc5gHdrNP5zw", "e": "AQAB"} + ] +} diff --git a/tests/clients/keys/rsa_private.pem b/tests/clients/keys/rsa_private.pem new file mode 100644 index 000000000..e8df41052 --- /dev/null +++ b/tests/clients/keys/rsa_private.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEApF1JaMSN8TEsh4N4O/5SpEAVLivJyLH+Cgl3OQBPGgJkt8cg +49oasl+5iJS+VdrILxWM9/JCJyURpUuslX4Eb4eUBtQ0x5BaPa8+S2NLdGTaL7nB +OO8o8n0C5FEUU+qlEip79KE8aqOj+OC44VsIquSmOvWIQD26n3fCVlgwoRBD1gzz +sDOeaSyzpKrZR851Kh6rEmF2qjJ8jt6EkxMsRNACmBomzgA4M1TTsisSUO87444p +e35Z4/n5c735o2fZMrGgMwiJNh7rT8SYxtIkxngioiGnwkxGQxQ4NzPAHg+XSY0J +04pNm7KqTkgtxyrqOANJLIjXlR+U9SQ90NjHVQIDAQABAoIBABuBPOKaWcJt3yzC +NGGduoif7KtwSnEaUA+v69KPGa2Zju8uFHPssKD+4dZYRc2qMeunKJLpaGaSjnRh +yHyvvOBJCN1nr3lhz6gY5kzJTfwpUFXCOPJlGy4Q+2Xnp4YvcvYqQ9n5DVovDiZ8 +vJOBn16xqpudMPLHIa7D5LJ8SY76HBjE+imTXw1EShdh5TOV9bmPFQqH6JFzowRH +hyH2DPHuyHJj6cl8FyqJw5lVWzG3n6Prvk7bYHsjmGjurN35UsumNAp6VouNyUP1 +RAEcUJega49aIs6/FJ0ENJzQjlsAzVbTleHkpez2aIok+wsWJGJ4SVxAjADOWAaZ +uEJPc3UCgYEA1g4ZGrXOuo75p9/MRIepXGpBWxip4V7B9XmO9WzPCv8nMorJntWB +msYV1I01aITxadHatO4Gl2xLniNkDyrEQzJ7w38RQgsVK+CqbnC0K9N77QPbHeC1 +YQd9RCNyUohOimKvb7jyv798FBU1GO5QI2eNgfnnfteSVXhD2iOoTOsCgYEAxJJ+ +8toxJdnLa0uUsAbql6zeNXGbUBMzu3FomKlyuWuq841jS2kIalaO/TRj5hbnE45j +mCjeLgTVO6Ach3Wfk4zrqajqfFJ0zUg/Wexp49lC3RWiV4icBb85Q6bzeJD9Dn9v +hjpfWVkczf/NeA1fGH/pcgfkT6Dm706GFFttLL8CgYBl/HeXk1H47xAiHO4dJKnb +v0B+X8To/RXamF01r+8BpUoOubOQetdyX7ic+d6deuHu8i6LD/GSCeYJZYFR/KVg +AtiW757QYalnq3ZogkhFrVCZP8IRfTPOFBxp752TlyAcrSI7T9pQ47IBe4094KXM +CJWSfPgAJkOxd0iU0XJpmwKBgGfQxuMTgSlwYRKFlD1zKap5TdID8fbUbVnth0Q5 +GbH7vwlp/qrxCdS/aj0n0irOpbOaW9ccnlrHiqY25VpVMLYIkt3DrDOEiNNx+KNR +TItdTwbcSiTYrS4L0/56ydM/H6bsfsXxRjI18hSJqMZiqXqS84OZz2aOn+h7HCzc +LEiZAoGASk20wFvilpRKHq79xxFWiDUPHi0x0pp82dYIEntGQkKUWkbSlhgf3MAi +5NEQTDmXdnB+rVeWIvEi+BXfdnNgdn8eC4zSdtF4sIAhYr5VWZo0WVWDhT7u2ccv +ZBFymiz8lo3gN57wGUCi9pbZqzV1+ZppX6YTNDdDCE0q+KO3Cec= +-----END RSA PRIVATE KEY----- diff --git a/tests/core/test_requests_client/__init__.py b/tests/clients/test_django/__init__.py similarity index 100% rename from tests/core/test_requests_client/__init__.py rename to tests/clients/test_django/__init__.py diff --git a/tests/clients/test_django/conftest.py b/tests/clients/test_django/conftest.py new file mode 100644 index 000000000..2fbab8774 --- /dev/null +++ b/tests/clients/test_django/conftest.py @@ -0,0 +1,8 @@ +import pytest + +from tests.django_helper import RequestClient + + +@pytest.fixture +def factory(): + return RequestClient() diff --git a/tests/clients/test_django/test_oauth_client.py b/tests/clients/test_django/test_oauth_client.py new file mode 100644 index 000000000..550135471 --- /dev/null +++ b/tests/clients/test_django/test_oauth_client.py @@ -0,0 +1,514 @@ +import time +from unittest import mock + +import pytest +from django.test import override_settings +from joserfc import jwk +from joserfc import jwt + +from authlib.common.urls import url_decode +from authlib.common.urls import urlparse +from authlib.integrations.django_client import OAuth +from authlib.integrations.django_client import OAuthError +from authlib.oidc.core.grants.util import create_half_hash + +from ..util import get_bearer_token +from ..util import mock_send_value + +dev_client = {"client_id": "dev-key", "client_secret": "dev-secret"} + + +def test_register_remote_app(): + oauth = OAuth() + with pytest.raises(AttributeError): + oauth.dev # noqa:B018 + + oauth.register( + "dev", + client_id="dev", + client_secret="dev", + request_token_url="https://provider.test/request-token", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + assert oauth.dev.name == "dev" + assert oauth.dev.client_id == "dev" + + +def test_register_with_overwrite(): + oauth = OAuth() + oauth.register( + "dev_overwrite", + overwrite=True, + client_id="dev", + client_secret="dev", + request_token_url="https://provider.test/request-token", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + access_token_params={"foo": "foo"}, + authorize_url="https://provider.test/authorize", + ) + assert oauth.dev_overwrite.client_id == "dev-client-id" + assert oauth.dev_overwrite.access_token_params["foo"] == "foo-1" + + +@override_settings(AUTHLIB_OAUTH_CLIENTS={"dev": dev_client}) +def test_register_from_settings(): + oauth = OAuth() + oauth.register("dev") + assert oauth.dev.client_id == "dev-key" + assert oauth.dev.client_secret == "dev-secret" + + +def test_oauth1_authorize(factory): + request = factory.get("/login") + request.session = factory.session + + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + request_token_url="https://provider.test/request-token", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value("oauth_token=foo&oauth_verifier=baz") + + resp = client.authorize_redirect(request) + assert resp.status_code == 302 + url = resp.get("Location") + assert "oauth_token=foo" in url + + request2 = factory.get(url) + request2.session = request.session + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value("oauth_token=a&oauth_token_secret=b") + token = client.authorize_access_token(request2) + assert token["oauth_token"] == "a" + + +def test_oauth2_authorize(factory): + request = factory.get("/login") + request.session = factory.session + + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + rv = client.authorize_redirect(request, "https://client.test/callback") + assert rv.status_code == 302 + url = rv.get("Location") + assert "state=" in url + state = dict(url_decode(urlparse.urlparse(url).query))["state"] + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(get_bearer_token()) + request2 = factory.get(f"/authorize?state={state}&code=foo") + request2.session = request.session + + token = client.authorize_access_token(request2) + assert token["access_token"] == "a" + + +def test_oauth2_authorize_access_denied(factory): + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + + with mock.patch("requests.sessions.Session.send"): + request = factory.get("/?error=access_denied&error_description=Not+Allowed") + request.session = factory.session + with pytest.raises(OAuthError): + client.authorize_access_token(request) + + +def test_oauth2_authorize_code_challenge(factory): + request = factory.get("/login") + request.session = factory.session + + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + client_kwargs={"code_challenge_method": "S256"}, + ) + rv = client.authorize_redirect(request, "https://client.test/callback") + assert rv.status_code == 302 + url = rv.get("Location") + assert "state=" in url + assert "code_challenge=" in url + + state = dict(url_decode(urlparse.urlparse(url).query))["state"] + state_data = request.session[f"_state_dev_{state}"]["data"] + verifier = state_data["code_verifier"] + + def fake_send(sess, req, **kwargs): + assert f"code_verifier={verifier}" in req.body + return mock_send_value(get_bearer_token()) + + with mock.patch("requests.sessions.Session.send", fake_send): + request2 = factory.get(f"/authorize?state={state}&code=foo") + request2.session = request.session + token = client.authorize_access_token(request2) + assert token["access_token"] == "a" + + +def test_oauth2_authorize_code_verifier(factory): + request = factory.get("/login") + request.session = factory.session + + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + client_kwargs={"code_challenge_method": "S256"}, + ) + state = "foo" + code_verifier = "bar" + rv = client.authorize_redirect( + request, + "https://client.test/callback", + state=state, + code_verifier=code_verifier, + ) + assert rv.status_code == 302 + url = rv.get("Location") + assert "state=" in url + assert "code_challenge=" in url + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(get_bearer_token()) + + request2 = factory.get(f"/authorize?state={state}&code=foo") + request2.session = request.session + + token = client.authorize_access_token(request2) + assert token["access_token"] == "a" + + +def test_openid_authorize(factory): + request = factory.get("/login") + request.session = factory.session + secret_key = jwk.import_key("secret", "oct") + + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + jwks={"keys": [secret_key.as_dict()]}, + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + client_kwargs={"scope": "openid profile"}, + ) + + resp = client.authorize_redirect(request, "https://client.test/callback") + assert resp.status_code == 302 + url = resp.get("Location") + assert "nonce=" in url + query_data = dict(url_decode(urlparse.urlparse(url).query)) + + token = get_bearer_token() + now = int(time.time()) + claims = { + "sub": "123", + "iss": "https://provider.test", + "aud": "dev", + "iat": now, + "auth_time": now, + "exp": now + 3600, + "nonce": query_data["nonce"], + "at_hash": create_half_hash(token["access_token"], "HS256").decode("utf-8"), + } + id_token = jwt.encode({"alg": "HS256"}, claims, secret_key) + token["id_token"] = id_token + state = query_data["state"] + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(token) + + request2 = factory.get(f"/authorize?state={state}&code=foo") + request2.session = request.session + + token = client.authorize_access_token(request2) + assert token["access_token"] == "a" + assert "userinfo" in token + assert token["userinfo"]["sub"] == "123" + + +def test_oauth2_access_token_with_post(factory): + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + payload = {"code": "a", "state": "b"} + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(get_bearer_token()) + request = factory.post("/token", data=payload) + request.session = factory.session + request.session["_state_dev_b"] = {"data": {}} + token = client.authorize_access_token(request) + assert token["access_token"] == "a" + + +def test_with_fetch_token_in_oauth(factory): + def fetch_token(name, request): + return {"access_token": name, "token_type": "bearer"} + + oauth = OAuth(fetch_token=fetch_token) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + + def fake_send(sess, req, **kwargs): + assert sess.token["access_token"] == "dev" + return mock_send_value(get_bearer_token()) + + with mock.patch("requests.sessions.Session.send", fake_send): + request = factory.get("/login") + client.get("/user", request=request) + + +def test_with_fetch_token_in_register(factory): + def fetch_token(request): + return {"access_token": "dev", "token_type": "bearer"} + + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + fetch_token=fetch_token, + ) + + def fake_send(sess, req, **kwargs): + assert sess.token["access_token"] == "dev" + return mock_send_value(get_bearer_token()) + + with mock.patch("requests.sessions.Session.send", fake_send): + request = factory.get("/login") + client.get("/user", request=request) + + +def test_request_without_token(): + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + + def fake_send(sess, req, **kwargs): + auth = req.headers.get("Authorization") + assert auth is None + resp = mock.MagicMock() + resp.text = "hi" + resp.status_code = 200 + return resp + + with mock.patch("requests.sessions.Session.send", fake_send): + resp = client.get("/api/user", withhold_token=True) + assert resp.text == "hi" + with pytest.raises(OAuthError): + client.get("https://resource.test/api/user") + + +def test_logout_redirect(factory): + """Test logout_redirect generates correct URL with state stored in session.""" + request = factory.get("/logout") + request.session = factory.session + + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + + metadata = { + "issuer": "https://provider.test", + "end_session_endpoint": "https://provider.test/logout", + } + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(metadata) + + resp = client.logout_redirect( + request, + post_logout_redirect_uri="https://client.test/logged-out", + id_token_hint="fake.id.token", + ) + assert resp.status_code == 302 + url = resp.get("Location") + assert "https://provider.test/logout" in url + assert "id_token_hint=fake.id.token" in url + assert "post_logout_redirect_uri" in url + assert "state=" in url + + # Verify state is stored in session + params = dict(url_decode(urlparse.urlparse(url).query)) + state = params["state"] + assert f"_state_dev_{state}" in request.session + + +def test_logout_redirect_without_redirect_uri(factory): + """Test logout_redirect omits state when no post_logout_redirect_uri is provided.""" + request = factory.get("/logout") + request.session = factory.session + + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + + metadata = { + "issuer": "https://provider.test", + "end_session_endpoint": "https://provider.test/logout", + } + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(metadata) + + resp = client.logout_redirect(request, id_token_hint="fake.id.token") + assert resp.status_code == 302 + url = resp.get("Location") + assert "id_token_hint=fake.id.token" in url + assert "state" not in url + + +def test_logout_redirect_missing_endpoint(factory): + """Test logout_redirect raises RuntimeError when end_session_endpoint is missing.""" + request = factory.get("/logout") + request.session = factory.session + + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + + metadata = { + "issuer": "https://provider.test", + "authorization_endpoint": "https://provider.test/authorize", + } + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(metadata) + + with pytest.raises(RuntimeError, match='Missing "end_session_endpoint"'): + client.logout_redirect(request) + + +def test_validate_logout_response(factory): + """Test validate_logout_response verifies state and returns stored data.""" + request = factory.get("/logout") + request.session = factory.session + + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + + metadata = { + "issuer": "https://provider.test", + "end_session_endpoint": "https://provider.test/logout", + } + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(metadata) + + resp = client.logout_redirect( + request, + post_logout_redirect_uri="https://client.test/logged-out", + ) + url = resp.get("Location") + params = dict(url_decode(urlparse.urlparse(url).query)) + state = params["state"] + + request2 = factory.get(f"/logged-out?state={state}") + request2.session = request.session + state_data = client.validate_logout_response(request2) + assert ( + state_data["post_logout_redirect_uri"] == "https://client.test/logged-out" + ) + # State should be cleared from session + assert f"_state_dev_{state}" not in request2.session + + +def test_validate_logout_response_missing_state(factory): + """Test validate_logout_response raises OAuthError when state is missing.""" + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + + request = factory.get("/logged-out") + request.session = factory.session + with pytest.raises(OAuthError, match='Missing "state" parameter'): + client.validate_logout_response(request) + + +def test_validate_logout_response_invalid_state(factory): + """Test validate_logout_response raises OAuthError when state is invalid.""" + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + + request = factory.get("/logged-out?state=invalid-state") + request.session = factory.session + with pytest.raises(OAuthError, match='Invalid "state" parameter'): + client.validate_logout_response(request) diff --git a/tests/django/test_client/__init__.py b/tests/clients/test_flask/__init__.py similarity index 100% rename from tests/django/test_client/__init__.py rename to tests/clients/test_flask/__init__.py diff --git a/tests/clients/test_flask/test_oauth_client.py b/tests/clients/test_flask/test_oauth_client.py new file mode 100644 index 000000000..b30eebef4 --- /dev/null +++ b/tests/clients/test_flask/test_oauth_client.py @@ -0,0 +1,916 @@ +import time +from unittest import mock + +import pytest +from cachelib import SimpleCache +from flask import Flask +from flask import session +from joserfc import jwk +from joserfc import jwt + +from authlib.common.urls import url_decode +from authlib.common.urls import urlparse +from authlib.integrations.flask_client import FlaskOAuth2App +from authlib.integrations.flask_client import OAuth +from authlib.integrations.flask_client import OAuthError +from authlib.oauth2.rfc6749.errors import MissingCodeException +from authlib.oidc.core.grants.util import create_half_hash + +from ..util import get_bearer_token +from ..util import mock_send_value + + +def test_register_remote_app(): + app = Flask(__name__) + oauth = OAuth(app) + with pytest.raises(AttributeError): + oauth.dev # noqa:B018 + + oauth.register( + "dev", + client_id="dev", + client_secret="dev", + ) + assert oauth.dev.name == "dev" + assert oauth.dev.client_id == "dev" + + +def test_register_conf_from_app(): + app = Flask(__name__) + app.config.update( + { + "DEV_CLIENT_ID": "dev", + "DEV_CLIENT_SECRET": "dev", + } + ) + oauth = OAuth(app) + oauth.register("dev") + assert oauth.dev.client_id == "dev" + + +def test_register_with_overwrite(): + app = Flask(__name__) + app.config.update( + { + "DEV_CLIENT_ID": "dev-1", + "DEV_CLIENT_SECRET": "dev", + "DEV_ACCESS_TOKEN_PARAMS": {"foo": "foo-1"}, + } + ) + oauth = OAuth(app) + oauth.register( + "dev", overwrite=True, client_id="dev", access_token_params={"foo": "foo"} + ) + assert oauth.dev.client_id == "dev-1" + assert oauth.dev.client_secret == "dev" + assert oauth.dev.access_token_params["foo"] == "foo-1" + + +def test_init_app_later(): + app = Flask(__name__) + app.config.update( + { + "DEV_CLIENT_ID": "dev", + "DEV_CLIENT_SECRET": "dev", + } + ) + oauth = OAuth() + remote = oauth.register("dev") + with pytest.raises(RuntimeError): + oauth.dev.client_id # noqa:B018 + oauth.init_app(app) + assert oauth.dev.client_id == "dev" + assert remote.client_id == "dev" + + assert oauth.cache is None + assert oauth.fetch_token is None + assert oauth.update_token is None + + +def test_init_app_params(): + app = Flask(__name__) + oauth = OAuth() + oauth.init_app(app, SimpleCache()) + assert oauth.cache is not None + assert oauth.update_token is None + + oauth.init_app(app, update_token=lambda o: o) + assert oauth.update_token is not None + + +def test_create_client(): + app = Flask(__name__) + oauth = OAuth(app) + assert oauth.create_client("dev") is None + oauth.register("dev", client_id="dev") + assert oauth.create_client("dev") is not None + + +def test_register_oauth1_remote_app(): + app = Flask(__name__) + oauth = OAuth(app) + client_kwargs = dict( + client_id="dev", + client_secret="dev", + request_token_url="https://provider.test/request-token", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + fetch_request_token=lambda: None, + save_request_token=lambda token: token, + ) + oauth.register("dev", **client_kwargs) + assert oauth.dev.name == "dev" + assert oauth.dev.client_id == "dev" + + oauth = OAuth(app, cache=SimpleCache()) + oauth.register("dev", **client_kwargs) + assert oauth.dev.name == "dev" + assert oauth.dev.client_id == "dev" + + +def test_oauth1_authorize_cache(): + app = Flask(__name__) + app.secret_key = "!" + cache = SimpleCache() + oauth = OAuth(app, cache=cache) + + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + request_token_url="https://provider.test/request-token", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + + with app.test_request_context(): + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value("oauth_token=foo&oauth_verifier=baz") + resp = client.authorize_redirect("https://client.test/callback") + assert resp.status_code == 302 + url = resp.headers.get("Location") + assert "oauth_token=foo" in url + session_data = session["_state_dev_foo"] + assert "exp" in session_data + assert "data" not in session_data + + with app.test_request_context("/?oauth_token=foo"): + with mock.patch("requests.sessions.Session.send") as send: + session["_state_dev_foo"] = session_data + send.return_value = mock_send_value("oauth_token=a&oauth_token_secret=b") + token = client.authorize_access_token() + assert token["oauth_token"] == "a" + + +def test_oauth1_authorize_session(): + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + request_token_url="https://provider.test/request-token", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + + with app.test_request_context(): + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value("oauth_token=foo&oauth_verifier=baz") + resp = client.authorize_redirect("https://client.test/callback") + assert resp.status_code == 302 + url = resp.headers.get("Location") + assert "oauth_token=foo" in url + data = session["_state_dev_foo"] + + with app.test_request_context("/?oauth_token=foo"): + session["_state_dev_foo"] = data + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value("oauth_token=a&oauth_token_secret=b") + token = client.authorize_access_token() + assert token["oauth_token"] == "a" + + +def test_register_oauth2_remote_app(): + app = Flask(__name__) + oauth = OAuth(app) + oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + refresh_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + update_token=lambda name: "hi", + ) + assert oauth.dev.name == "dev" + session = oauth.dev._get_oauth_client() + assert session.update_token is not None + + +def test_oauth2_authorize_cache(): + app = Flask(__name__) + app.secret_key = "!" + cache = SimpleCache() + oauth = OAuth(app, cache=cache) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + with app.test_request_context(): + resp = client.authorize_redirect("https://client.test/callback") + assert resp.status_code == 302 + url = resp.headers.get("Location") + assert "state=" in url + state = dict(url_decode(urlparse.urlparse(url).query))["state"] + assert state is not None + session_data = session[f"_state_dev_{state}"] + assert "exp" in session_data + assert "data" not in session_data + + with app.test_request_context(path=f"/?code=a&state={state}"): + # session is cleared in tests + session[f"_state_dev_{state}"] = session_data + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(get_bearer_token()) + token = client.authorize_access_token() + assert token["access_token"] == "a" + + with app.test_request_context(): + assert client.token is None + + +def test_oauth2_authorize_session(): + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + + with app.test_request_context(): + resp = client.authorize_redirect("https://client.test/callback") + assert resp.status_code == 302 + url = resp.headers.get("Location") + assert "state=" in url + state = dict(url_decode(urlparse.urlparse(url).query))["state"] + assert state is not None + session_data = session[f"_state_dev_{state}"] + assert "exp" in session_data + assert "data" in session_data + + with app.test_request_context(path=f"/?code=a&state={state}"): + # session is cleared in tests + session[f"_state_dev_{state}"] = session_data + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(get_bearer_token()) + token = client.authorize_access_token() + assert token["access_token"] == "a" + + with app.test_request_context(): + assert client.token is None + + +def test_oauth2_authorize_access_denied(): + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + + with app.test_request_context( + path="/?error=access_denied&error_description=Not+Allowed" + ): + # session is cleared in tests + with mock.patch("requests.sessions.Session.send"): + with pytest.raises(OAuthError): + client.authorize_access_token() + + +def test_oauth2_authorize_via_custom_client(): + class CustomRemoteApp(FlaskOAuth2App): + OAUTH_APP_CONFIG = {"authorize_url": "https://provider.test/custom"} + + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + client_cls=CustomRemoteApp, + ) + with app.test_request_context(): + resp = client.authorize_redirect("https://client.test/callback") + assert resp.status_code == 302 + url = resp.headers.get("Location") + assert url.startswith("https://provider.test/custom?") + + +def test_oauth2_fetch_metadata(): + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + with mock.patch("requests.sessions.Session.send") as send: + + def check_request(req, **kwargs): + assert "Authlib/" in req.headers.get("user-agent", "") + if req.url == "https://provider.test/.well-known/openid-configuration": + return mock_send_value( + { + "authorization_endpoint": "https://provider.test/authorize", + "jwks_uri": "https://provider.test/.well-known/keys", + } + ) + if req.url == "https://provider.test/.well-known/keys": + return mock_send_value({"keys": []}) + return mock.DEFAULT + + send.side_effect = check_request + + with app.test_request_context(): + client.fetch_jwk_set() + + assert send.call_count == 2 + + +def test_oauth2_authorize_with_metadata(): + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + ) + with pytest.raises(RuntimeError): + client.create_authorization_url(None) + + client = oauth.register( + "dev2", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value( + {"authorization_endpoint": "https://provider.test/authorize"} + ) + + with app.test_request_context(): + resp = client.authorize_redirect("https://client.test/callback") + assert resp.status_code == 302 + + +def test_oauth2_authorize_code_challenge(): + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + client_kwargs={"code_challenge_method": "S256"}, + ) + + with app.test_request_context(): + resp = client.authorize_redirect("https://client.test/callback") + assert resp.status_code == 302 + url = resp.headers.get("Location") + assert "code_challenge=" in url + assert "code_challenge_method=S256" in url + + state = dict(url_decode(urlparse.urlparse(url).query))["state"] + assert state is not None + data = session[f"_state_dev_{state}"] + + verifier = data["data"]["code_verifier"] + assert verifier is not None + + def fake_send(sess, req, **kwargs): + assert f"code_verifier={verifier}" in req.body + return mock_send_value(get_bearer_token()) + + path = f"/?code=a&state={state}" + with app.test_request_context(path=path): + # session is cleared in tests + session[f"_state_dev_{state}"] = data + + with mock.patch("requests.sessions.Session.send", fake_send): + token = client.authorize_access_token() + assert token["access_token"] == "a" + + +def test_openid_authorize(): + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + key = jwk.import_key("secret", "oct") + + client = oauth.register( + "dev", + client_id="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + client_kwargs={"scope": "openid profile"}, + jwks={"keys": [key.as_dict()]}, + ) + + with app.test_request_context(): + resp = client.authorize_redirect("https://client.test/callback") + assert resp.status_code == 302 + + url = resp.headers["Location"] + query_data = dict(url_decode(urlparse.urlparse(url).query)) + + state = query_data["state"] + assert state is not None + session_data = session[f"_state_dev_{state}"] + nonce = session_data["data"]["nonce"] + assert nonce is not None + assert nonce == query_data["nonce"] + + token = get_bearer_token() + now = int(time.time()) + claims = { + "sub": "123", + "iss": "https://provider.test", + "aud": "dev", + "iat": now, + "auth_time": now, + "exp": now + 3600, + "nonce": query_data["nonce"], + "at_hash": create_half_hash(token["access_token"], "HS256").decode("utf-8"), + } + id_token = jwt.encode({"alg": "HS256"}, claims, key) + token["id_token"] = id_token + path = f"/?code=a&state={state}" + with app.test_request_context(path=path): + session[f"_state_dev_{state}"] = session_data + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(token) + token = client.authorize_access_token() + assert token["access_token"] == "a" + assert "userinfo" in token + + +def test_oauth2_access_token_with_post(): + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + payload = {"code": "a", "state": "b"} + with app.test_request_context(data=payload, method="POST"): + session["_state_dev_b"] = {"data": payload} + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(get_bearer_token()) + token = client.authorize_access_token() + assert token["access_token"] == "a" + + +def test_access_token_with_fetch_token(): + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth() + + token = get_bearer_token() + oauth.init_app(app, fetch_token=lambda name: token) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + + def fake_send(sess, req, **kwargs): + auth = req.headers["Authorization"] + assert auth == "Bearer {}".format(token["access_token"]) + resp = mock.MagicMock() + resp.text = "hi" + resp.status_code = 200 + return resp + + with app.test_request_context(): + with mock.patch("requests.sessions.Session.send", fake_send): + resp = client.get("/api/user") + assert resp.text == "hi" + + # trigger ctx.authlib_client_oauth_token + resp = client.get("/api/user") + assert resp.text == "hi" + + +def test_request_with_refresh_token(): + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth() + + expired_token = { + "token_type": "Bearer", + "access_token": "expired-a", + "refresh_token": "expired-b", + "expires_in": "3600", + "expires_at": 1566465749, + } + oauth.init_app(app, fetch_token=lambda name: expired_token) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + refresh_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + + def fake_send(sess, req, **kwargs): + if req.url == "https://provider.test/token": + auth = req.headers["Authorization"] + assert "Basic" in auth + resp = mock.MagicMock() + resp.json = get_bearer_token + resp.status_code = 200 + return resp + + resp = mock.MagicMock() + resp.text = "hi" + resp.status_code = 200 + return resp + + with app.test_request_context(): + with mock.patch("requests.sessions.Session.send", fake_send): + resp = client.get("/api/user", token=expired_token) + assert resp.text == "hi" + + +def test_request_without_token(): + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + + def fake_send(sess, req, **kwargs): + auth = req.headers.get("Authorization") + assert auth is None + resp = mock.MagicMock() + resp.text = "hi" + resp.status_code = 200 + return resp + + with app.test_request_context(): + with mock.patch("requests.sessions.Session.send", fake_send): + resp = client.get("/api/user", withhold_token=True) + assert resp.text == "hi" + with pytest.raises(OAuthError): + client.get("https://resource.test/api/user") + + +def test_oauth2_authorize_missing_code(): + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + + with app.test_request_context(): + resp = client.authorize_redirect("https://client.test/callback") + state = dict(url_decode(urlparse.urlparse(resp.headers["Location"]).query))[ + "state" + ] + session_data = session[f"_state_dev_{state}"] + + # Test missing code parameter + with app.test_request_context(path=f"/?state={state}"): + session[f"_state_dev_{state}"] = session_data + with pytest.raises(MissingCodeException) as exc_info: + client.authorize_access_token() + assert exc_info.value.error == "missing_code" + + +def test_logout_redirect_with_metadata(): + """Test logout_redirect generates correct URL with state stored in session.""" + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + + metadata = { + "issuer": "https://provider.test", + "end_session_endpoint": "https://provider.test/logout", + } + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(metadata) + + with app.test_request_context(): + resp = client.logout_redirect( + post_logout_redirect_uri="https://client.test/logged-out", + id_token_hint="fake.id.token", + ) + assert resp.status_code == 302 + url = resp.headers.get("Location") + assert "https://provider.test/logout" in url + assert "id_token_hint=fake.id.token" in url + assert ( + "post_logout_redirect_uri=https%3A%2F%2Fclient.test%2Flogged-out" in url + ) + assert "state=" in url + + # Verify state is stored in session + params = dict(url_decode(urlparse.urlparse(url).query)) + state = params["state"] + assert f"_state_dev_{state}" in session + + +def test_logout_redirect_without_redirect_uri(): + """Test logout_redirect omits state when no post_logout_redirect_uri is provided.""" + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + + metadata = { + "issuer": "https://provider.test", + "end_session_endpoint": "https://provider.test/logout", + } + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(metadata) + + with app.test_request_context(): + resp = client.logout_redirect(id_token_hint="fake.id.token") + assert resp.status_code == 302 + url = resp.headers.get("Location") + assert "https://provider.test/logout" in url + assert "id_token_hint=fake.id.token" in url + assert "post_logout_redirect_uri" not in url + assert "state" not in url + + # No state stored when no redirect_uri + assert not any(k.startswith("_state_dev_") for k in session.keys()) + + +def test_logout_redirect_with_extra_params(): + """Test logout_redirect includes optional params: client_id, logout_hint, ui_locales.""" + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + + metadata = { + "issuer": "https://provider.test", + "end_session_endpoint": "https://provider.test/logout", + } + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(metadata) + + with app.test_request_context(): + resp = client.logout_redirect( + post_logout_redirect_uri="https://client.test/logged-out", + client_id="dev", + logout_hint="user@example.com", + ui_locales="fr", + ) + assert resp.status_code == 302 + url = resp.headers.get("Location") + assert "client_id=dev" in url + assert "logout_hint=user%40example.com" in url + assert "ui_locales=fr" in url + + +def test_logout_redirect_with_custom_state(): + """Test logout_redirect uses a custom state value when provided.""" + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + + metadata = { + "issuer": "https://provider.test", + "end_session_endpoint": "https://provider.test/logout", + } + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(metadata) + + with app.test_request_context(): + resp = client.logout_redirect( + post_logout_redirect_uri="https://client.test/logged-out", + state="custom-state-123", + ) + assert resp.status_code == 302 + url = resp.headers.get("Location") + assert "state=custom-state-123" in url + + +def test_logout_redirect_missing_endpoint(): + """Test logout_redirect raises RuntimeError when end_session_endpoint is missing.""" + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + + # Metadata without end_session_endpoint + metadata = { + "issuer": "https://provider.test", + "authorization_endpoint": "https://provider.test/authorize", + } + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(metadata) + + with app.test_request_context(): + with pytest.raises(RuntimeError, match='Missing "end_session_endpoint"'): + client.logout_redirect() + + +def test_create_logout_url_directly(): + """Test create_logout_url returns URL and state without performing redirect.""" + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + + metadata = { + "issuer": "https://provider.test", + "end_session_endpoint": "https://provider.test/logout", + } + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(metadata) + + with app.test_request_context(): + result = client.create_logout_url( + post_logout_redirect_uri="https://client.test/logged-out", + id_token_hint="fake.id.token", + ) + assert "url" in result + assert "state" in result + assert result["state"] is not None + assert "https://provider.test/logout" in result["url"] + + +def test_validate_logout_response(): + """Test validate_logout_response verifies state and returns stored data.""" + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + + metadata = { + "issuer": "https://provider.test", + "end_session_endpoint": "https://provider.test/logout", + } + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(metadata) + + with app.test_request_context(): + resp = client.logout_redirect( + post_logout_redirect_uri="https://client.test/logged-out", + ) + url = resp.headers.get("Location") + params = dict(url_decode(urlparse.urlparse(url).query)) + state = params["state"] + session_data = session[f"_state_dev_{state}"] + + with app.test_request_context(path=f"/?state={state}"): + session[f"_state_dev_{state}"] = session_data + state_data = client.validate_logout_response() + assert ( + state_data["post_logout_redirect_uri"] + == "https://client.test/logged-out" + ) + # State should be cleared from session + assert f"_state_dev_{state}" not in session + + +def test_validate_logout_response_missing_state(): + """Test validate_logout_response raises OAuthError when state is missing.""" + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + + with app.test_request_context(path="/"): + with pytest.raises(OAuthError, match='Missing "state" parameter'): + client.validate_logout_response() + + +def test_validate_logout_response_invalid_state(): + """Test validate_logout_response raises OAuthError when state is invalid.""" + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + + with app.test_request_context(path="/?state=invalid-state"): + with pytest.raises(OAuthError, match='Invalid "state" parameter'): + client.validate_logout_response() diff --git a/tests/clients/test_flask/test_user_mixin.py b/tests/clients/test_flask/test_user_mixin.py new file mode 100644 index 000000000..da07872dd --- /dev/null +++ b/tests/clients/test_flask/test_user_mixin.py @@ -0,0 +1,197 @@ +import time +from unittest import mock + +import pytest +from flask import Flask +from joserfc import jwt +from joserfc.errors import InvalidClaimError +from joserfc.jwk import KeySet +from joserfc.jwk import OctKey + +from authlib.integrations.flask_client import OAuth +from authlib.oidc.core.grants.util import create_half_hash + +from ..util import get_bearer_token +from ..util import read_key_file + +secret_key = OctKey.import_key("test-oct-secret", {"kty": "oct", "kid": "f"}) + + +def test_fetch_userinfo(): + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + fetch_token=get_bearer_token, + userinfo_endpoint="https://provider.test/userinfo", + ) + + def fake_send(sess, req, **kwargs): + resp = mock.MagicMock() + resp.json = lambda: {"sub": "123"} + resp.status_code = 200 + return resp + + with app.test_request_context(): + with mock.patch("requests.sessions.Session.send", fake_send): + user = client.userinfo() + assert user.sub == "123" + + +def test_parse_id_token(): + token = get_bearer_token() + now = int(time.time()) + claims = { + "sub": "123", + "iss": "https://provider.test", + "aud": "dev", + "iat": now, + "auth_time": now, + "exp": now + 3600, + "nonce": "n", + "at_hash": create_half_hash(token["access_token"], "HS256").decode("utf-8"), + } + id_token = jwt.encode({"alg": "HS256"}, claims, secret_key) + + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + fetch_token=get_bearer_token, + jwks={"keys": [secret_key.as_dict()]}, + issuer="https://provider.test", + id_token_signing_alg_values_supported=["HS256", "RS256"], + ) + with app.test_request_context(): + assert client.parse_id_token(token, nonce="n") is None + + token["id_token"] = id_token + user = client.parse_id_token(token, nonce="n") + assert user.sub == "123" + + claims_options = {"iss": {"value": "https://provider.test"}} + user = client.parse_id_token(token, nonce="n", claims_options=claims_options) + assert user.sub == "123" + + claims_options = {"iss": {"value": "https://wrong-provider.test"}} + with pytest.raises(InvalidClaimError): + client.parse_id_token(token, "n", claims_options) + + +def test_parse_id_token_nonce_supported(): + token = get_bearer_token() + + now = int(time.time()) + claims = { + "sub": "123", + "nonce_supported": False, + "iss": "https://provider.test", + "aud": "dev", + "iat": now, + "auth_time": now, + "exp": now + 3600, + "at_hash": create_half_hash(token["access_token"], "HS256").decode("utf-8"), + } + id_token = jwt.encode({"alg": "HS256"}, claims, secret_key) + + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + fetch_token=get_bearer_token, + jwks={"keys": [secret_key.as_dict()]}, + issuer="https://provider.test", + id_token_signing_alg_values_supported=["HS256", "RS256"], + ) + with app.test_request_context(): + token["id_token"] = id_token + user = client.parse_id_token(token, nonce="n") + assert user.sub == "123" + + +def test_runtime_error_fetch_jwks_uri(): + token = get_bearer_token() + now = int(time.time()) + claims = { + "sub": "123", + "nonce": "n", + "iss": "https://provider.test", + "aud": "dev", + "iat": now, + "auth_time": now, + "exp": now + 3600, + "at_hash": create_half_hash(token["access_token"], "HS256").decode("utf-8"), + } + id_token = jwt.encode({"alg": "HS256", "kid": "not-found"}, claims, secret_key) + + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + alt_key = secret_key.as_dict() + alt_key["kid"] = "b" + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + fetch_token=get_bearer_token, + jwks={"keys": [alt_key]}, + issuer="https://provider.test", + id_token_signing_alg_values_supported=["HS256"], + ) + with app.test_request_context(): + token["id_token"] = id_token + with pytest.raises(RuntimeError): + client.parse_id_token(token, "n") + + +def test_force_fetch_jwks_uri(): + secret_keys = KeySet.import_key_set(read_key_file("jwks_private.json")) + token = get_bearer_token() + now = int(time.time()) + claims = { + "sub": "123", + "nonce": "n", + "iss": "https://provider.test", + "aud": "dev", + "iat": now, + "auth_time": now, + "exp": now + 3600, + "at_hash": create_half_hash(token["access_token"], "RS256").decode("utf-8"), + } + id_token = jwt.encode({"alg": "RS256"}, claims, secret_keys) + + app = Flask(__name__) + app.secret_key = "!" + oauth = OAuth(app) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + fetch_token=get_bearer_token, + jwks={"keys": [secret_key.as_dict()]}, + jwks_uri="https://provider.test/jwks", + issuer="https://provider.test", + ) + + def fake_send(sess, req, **kwargs): + resp = mock.MagicMock() + resp.json = lambda: read_key_file("jwks_public.json") + resp.status_code = 200 + return resp + + with app.test_request_context(): + assert client.parse_id_token(token, nonce="n") is None + + with mock.patch("requests.sessions.Session.send", fake_send): + token["id_token"] = id_token + user = client.parse_id_token(token, nonce="n") + assert user.sub == "123" diff --git a/tests/flask/test_client/__init__.py b/tests/clients/test_httpx/__init__.py similarity index 100% rename from tests/flask/test_client/__init__.py rename to tests/clients/test_httpx/__init__.py diff --git a/tests/clients/test_httpx/test_assertion_client.py b/tests/clients/test_httpx/test_assertion_client.py new file mode 100644 index 000000000..d6a980c8f --- /dev/null +++ b/tests/clients/test_httpx/test_assertion_client.py @@ -0,0 +1,65 @@ +import time + +import pytest +from httpx import WSGITransport + +from authlib.integrations.httpx_client import AssertionClient + +from ..wsgi_helper import MockDispatch + +default_token = { + "token_type": "Bearer", + "access_token": "a", + "refresh_token": "b", + "expires_in": "3600", + "expires_at": int(time.time()) + 3600, +} + + +def test_refresh_token(): + def verifier(request): + content = request.form + if str(request.url) == "https://provider.test/token": + assert "assertion" in content + + with AssertionClient( + "https://provider.test/token", + issuer="foo", + subject="foo", + audience="foo", + alg="HS256", + key="secret", + transport=WSGITransport(MockDispatch(default_token, assert_func=verifier)), + ) as client: + client.get("https://provider.test") + + # trigger more case + now = int(time.time()) + with AssertionClient( + "https://provider.test/token", + issuer="foo", + subject=None, + audience="foo", + issued_at=now, + expires_at=now + 3600, + header={"alg": "HS256"}, + key="secret", + scope="email", + claims={"test_mode": "true"}, + transport=WSGITransport(MockDispatch(default_token, assert_func=verifier)), + ) as client: + client.get("https://provider.test") + client.get("https://provider.test") + + +def test_without_alg(): + with AssertionClient( + "https://provider.test/token", + issuer="foo", + subject="foo", + audience="foo", + key="secret", + transport=WSGITransport(MockDispatch(default_token)), + ) as client: + with pytest.raises(ValueError): + client.get("https://provider.test") diff --git a/tests/clients/test_httpx/test_async_assertion_client.py b/tests/clients/test_httpx/test_async_assertion_client.py new file mode 100644 index 000000000..289d077e7 --- /dev/null +++ b/tests/clients/test_httpx/test_async_assertion_client.py @@ -0,0 +1,68 @@ +import time + +import pytest +from httpx import ASGITransport + +from authlib.integrations.httpx_client import AsyncAssertionClient + +from ..asgi_helper import AsyncMockDispatch + +default_token = { + "token_type": "Bearer", + "access_token": "a", + "refresh_token": "b", + "expires_in": "3600", + "expires_at": int(time.time()) + 3600, +} + + +@pytest.mark.asyncio +async def test_refresh_token(): + async def verifier(request): + content = await request.body() + if str(request.url) == "https://provider.test/token": + assert b"assertion=" in content + + async with AsyncAssertionClient( + "https://provider.test/token", + grant_type=AsyncAssertionClient.JWT_BEARER_GRANT_TYPE, + issuer="foo", + subject="foo", + audience="foo", + alg="HS256", + key="secret", + transport=ASGITransport(AsyncMockDispatch(default_token, assert_func=verifier)), + ) as client: + await client.get("https://provider.test") + + # trigger more case + now = int(time.time()) + async with AsyncAssertionClient( + "https://provider.test/token", + issuer="foo", + subject=None, + audience="foo", + issued_at=now, + expires_at=now + 3600, + header={"alg": "HS256"}, + key="secret", + scope="email", + claims={"test_mode": "true"}, + transport=ASGITransport(AsyncMockDispatch(default_token, assert_func=verifier)), + ) as client: + await client.get("https://provider.test") + await client.get("https://provider.test") + + +@pytest.mark.asyncio +async def test_without_alg(): + async with AsyncAssertionClient( + "https://provider.test/token", + issuer="foo", + subject="foo", + audience="foo", + key="secret", + transport=ASGITransport(AsyncMockDispatch()), + ) as client: + with pytest.raises(ValueError): + await client.get("https://provider.test") diff --git a/tests/clients/test_httpx/test_async_oauth1_client.py b/tests/clients/test_httpx/test_async_oauth1_client.py new file mode 100644 index 000000000..d469d8329 --- /dev/null +++ b/tests/clients/test_httpx/test_async_oauth1_client.py @@ -0,0 +1,174 @@ +import pytest +from httpx import ASGITransport + +from authlib.integrations.httpx_client import SIGNATURE_TYPE_BODY +from authlib.integrations.httpx_client import SIGNATURE_TYPE_QUERY +from authlib.integrations.httpx_client import AsyncOAuth1Client +from authlib.integrations.httpx_client import OAuthError + +from ..asgi_helper import AsyncMockDispatch + +oauth_url = "https://provider.test/oauth" + + +@pytest.mark.asyncio +async def test_fetch_request_token_via_header(): + request_token = {"oauth_token": "1", "oauth_token_secret": "2"} + + async def assert_func(request): + auth_header = request.headers.get("authorization") + assert 'oauth_consumer_key="id"' in auth_header + assert "oauth_signature=" in auth_header + + transport = ASGITransport(AsyncMockDispatch(request_token, assert_func=assert_func)) + async with AsyncOAuth1Client("id", "secret", transport=transport) as client: + response = await client.fetch_request_token(oauth_url) + + assert response == request_token + + +@pytest.mark.asyncio +async def test_fetch_request_token_via_body(): + request_token = {"oauth_token": "1", "oauth_token_secret": "2"} + + async def assert_func(request): + auth_header = request.headers.get("authorization") + assert auth_header is None + + content = await request.body() + assert b"oauth_consumer_key=id" in content + assert b"&oauth_signature=" in content + + transport = ASGITransport(AsyncMockDispatch(request_token, assert_func=assert_func)) + + async with AsyncOAuth1Client( + "id", + "secret", + signature_type=SIGNATURE_TYPE_BODY, + transport=transport, + ) as client: + response = await client.fetch_request_token(oauth_url) + + assert response == request_token + + +@pytest.mark.asyncio +async def test_fetch_request_token_via_query(): + request_token = {"oauth_token": "1", "oauth_token_secret": "2"} + + async def assert_func(request): + auth_header = request.headers.get("authorization") + assert auth_header is None + + url = str(request.url) + assert "oauth_consumer_key=id" in url + assert "&oauth_signature=" in url + + transport = ASGITransport(AsyncMockDispatch(request_token, assert_func=assert_func)) + + async with AsyncOAuth1Client( + "id", + "secret", + signature_type=SIGNATURE_TYPE_QUERY, + transport=transport, + ) as client: + response = await client.fetch_request_token(oauth_url) + + assert response == request_token + + +@pytest.mark.asyncio +async def test_fetch_access_token(): + request_token = {"oauth_token": "1", "oauth_token_secret": "2"} + + async def assert_func(request): + auth_header = request.headers.get("authorization") + assert 'oauth_verifier="d"' in auth_header + assert 'oauth_token="foo"' in auth_header + assert 'oauth_consumer_key="id"' in auth_header + assert "oauth_signature=" in auth_header + + transport = ASGITransport(AsyncMockDispatch(request_token, assert_func=assert_func)) + async with AsyncOAuth1Client( + "id", + "secret", + token="foo", + token_secret="bar", + transport=transport, + ) as client: + with pytest.raises(OAuthError): + await client.fetch_access_token(oauth_url) + + response = await client.fetch_access_token(oauth_url, verifier="d") + + assert response == request_token + + +@pytest.mark.asyncio +async def test_get_via_header(): + transport = ASGITransport(AsyncMockDispatch(b"hello")) + async with AsyncOAuth1Client( + "id", + "secret", + token="foo", + token_secret="bar", + transport=transport, + ) as client: + response = await client.get("https://resource.test/") + + assert response.content == b"hello" + request = response.request + auth_header = request.headers.get("authorization") + assert 'oauth_token="foo"' in auth_header + assert 'oauth_consumer_key="id"' in auth_header + assert "oauth_signature=" in auth_header + + +@pytest.mark.asyncio +async def test_get_via_body(): + async def assert_func(request): + content = await request.body() + assert b"oauth_token=foo" in content + assert b"oauth_consumer_key=id" in content + assert b"oauth_signature=" in content + + transport = ASGITransport(AsyncMockDispatch(b"hello", assert_func=assert_func)) + async with AsyncOAuth1Client( + "id", + "secret", + token="foo", + token_secret="bar", + signature_type=SIGNATURE_TYPE_BODY, + transport=transport, + ) as client: + response = await client.post("https://resource.test/") + + assert response.content == b"hello" + + request = response.request + auth_header = request.headers.get("authorization") + assert auth_header is None + + +@pytest.mark.asyncio +async def test_get_via_query(): + transport = ASGITransport(AsyncMockDispatch(b"hello")) + async with AsyncOAuth1Client( + "id", + "secret", + token="foo", + token_secret="bar", + signature_type=SIGNATURE_TYPE_QUERY, + transport=transport, + ) as client: + response = await client.get("https://resource.test/") + + assert response.content == b"hello" + request = response.request + auth_header = request.headers.get("authorization") + assert auth_header is None + + url = str(request.url) + assert "oauth_token=foo" in url + assert "oauth_consumer_key=id" in url + assert "oauth_signature=" in url diff --git a/tests/clients/test_httpx/test_async_oauth2_client.py b/tests/clients/test_httpx/test_async_oauth2_client.py new file mode 100644 index 000000000..6b855815c --- /dev/null +++ b/tests/clients/test_httpx/test_async_oauth2_client.py @@ -0,0 +1,447 @@ +import asyncio +import time +from copy import deepcopy +from unittest import mock + +import pytest +from httpx import ASGITransport +from httpx import AsyncClient + +from authlib.common.security import generate_token +from authlib.common.urls import url_encode +from authlib.integrations.httpx_client import AsyncOAuth2Client +from authlib.integrations.httpx_client import OAuthError + +from ..asgi_helper import AsyncMockDispatch + +default_token = { + "token_type": "Bearer", + "access_token": "a", + "refresh_token": "b", + "expires_in": "3600", + "expires_at": int(time.time()) + 3600, +} + + +@pytest.mark.asyncio +async def assert_token_in_header(request): + token = "Bearer " + default_token["access_token"] + auth_header = request.headers.get("authorization") + assert auth_header == token + + +@pytest.mark.asyncio +async def assert_token_in_body(request): + content = await request.body() + assert default_token["access_token"] in content.decode() + + +@pytest.mark.asyncio +async def assert_token_in_uri(request): + assert default_token["access_token"] in str(request.url) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "assert_func, token_placement", + [ + (assert_token_in_header, "header"), + (assert_token_in_body, "body"), + (assert_token_in_uri, "uri"), + ], +) +async def test_add_token_get_request(assert_func, token_placement): + transport = ASGITransport(AsyncMockDispatch({"a": "a"}, assert_func=assert_func)) + async with AsyncOAuth2Client( + "foo", token=default_token, token_placement=token_placement, transport=transport + ) as client: + resp = await client.get("https://provider.test") + + data = resp.json() + assert data["a"] == "a" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "assert_func, token_placement", + [ + (assert_token_in_header, "header"), + (assert_token_in_body, "body"), + (assert_token_in_uri, "uri"), + ], +) +async def test_add_token_to_streaming_request(assert_func, token_placement): + transport = ASGITransport(AsyncMockDispatch({"a": "a"}, assert_func=assert_func)) + async with AsyncOAuth2Client( + "foo", token=default_token, token_placement=token_placement, transport=transport + ) as client: + async with client.stream("GET", "https://provider.test") as stream: + await stream.aread() + data = stream.json() + + assert data["a"] == "a" + + +@pytest.mark.parametrize( + "client", + [ + AsyncOAuth2Client( + "foo", + token=default_token, + token_placement="header", + transport=ASGITransport( + AsyncMockDispatch({"a": "a"}, assert_func=assert_token_in_header) + ), + ), + AsyncClient(transport=ASGITransport(AsyncMockDispatch({"a": "a"}))), + ], +) +async def test_httpx_client_stream_match(client): + async with client as client_entered: + async with client_entered.stream("GET", "https://provider.test") as stream: + assert stream.status_code == 200 + + +def test_create_authorization_url(): + url = "https://provider.test/authorize?foo=bar" + + sess = AsyncOAuth2Client(client_id="foo") + auth_url, state = sess.create_authorization_url(url) + assert state in auth_url + assert "client_id=foo" in auth_url + assert "response_type=code" in auth_url + + sess = AsyncOAuth2Client(client_id="foo", prompt="none") + auth_url, state = sess.create_authorization_url( + url, state="foo", redirect_uri="https://provider.test", scope="profile" + ) + assert state == "foo" + assert "provider.test" in auth_url + assert "profile" in auth_url + assert "prompt=none" in auth_url + + +def test_code_challenge(): + sess = AsyncOAuth2Client("foo", code_challenge_method="S256") + + url = "https://provider.test/authorize" + auth_url, _ = sess.create_authorization_url(url, code_verifier=generate_token(48)) + assert "code_challenge=" in auth_url + assert "code_challenge_method=S256" in auth_url + + +def test_token_from_fragment(): + sess = AsyncOAuth2Client("foo") + response_url = "https://provider.test/callback#" + url_encode(default_token.items()) + assert sess.token_from_fragment(response_url) == default_token + token = sess.fetch_token(authorization_response=response_url) + assert token == default_token + + +@pytest.mark.asyncio +async def test_fetch_token_post(): + url = "https://provider.test/token" + + async def assert_func(request): + content = await request.body() + content = content.decode() + assert "code=v" in content + assert "client_id=" in content + assert "grant_type=authorization_code" in content + + transport = ASGITransport(AsyncMockDispatch(default_token, assert_func=assert_func)) + async with AsyncOAuth2Client("foo", transport=transport) as client: + token = await client.fetch_token( + url, authorization_response="https://provider.test/?code=v" + ) + assert token == default_token + + async with AsyncOAuth2Client( + "foo", token_endpoint_auth_method="none", transport=transport + ) as client: + token = await client.fetch_token(url, code="v") + assert token == default_token + + transport = ASGITransport(AsyncMockDispatch({"error": "invalid_request"})) + async with AsyncOAuth2Client("foo", transport=transport) as client: + with pytest.raises(OAuthError): + await client.fetch_token(url) + + +@pytest.mark.asyncio +async def test_fetch_token_get(): + url = "https://provider.test/token" + + async def assert_func(request): + url = str(request.url) + assert "code=v" in url + assert "client_id=" in url + assert "grant_type=authorization_code" in url + + transport = ASGITransport(AsyncMockDispatch(default_token, assert_func=assert_func)) + async with AsyncOAuth2Client("foo", transport=transport) as client: + authorization_response = "https://provider.test/?code=v" + token = await client.fetch_token( + url, authorization_response=authorization_response, method="GET" + ) + assert token == default_token + + async with AsyncOAuth2Client( + "foo", token_endpoint_auth_method="none", transport=transport + ) as client: + token = await client.fetch_token(url, code="v", method="GET") + assert token == default_token + + token = await client.fetch_token(url + "?q=a", code="v", method="GET") + assert token == default_token + + +@pytest.mark.asyncio +async def test_token_auth_method_client_secret_post(): + url = "https://provider.test/token" + + async def assert_func(request): + content = await request.body() + content = content.decode() + assert "code=v" in content + assert "client_id=" in content + assert "client_secret=bar" in content + assert "grant_type=authorization_code" in content + + transport = ASGITransport(AsyncMockDispatch(default_token, assert_func=assert_func)) + async with AsyncOAuth2Client( + "foo", + "bar", + token_endpoint_auth_method="client_secret_post", + transport=transport, + ) as client: + token = await client.fetch_token(url, code="v") + + assert token == default_token + + +@pytest.mark.asyncio +async def test_access_token_response_hook(): + url = "https://provider.test/token" + + def _access_token_response_hook(resp): + assert resp.json() == default_token + return resp + + access_token_response_hook = mock.Mock(side_effect=_access_token_response_hook) + transport = ASGITransport(AsyncMockDispatch(default_token)) + async with AsyncOAuth2Client( + "foo", token=default_token, transport=transport + ) as sess: + sess.register_compliance_hook( + "access_token_response", access_token_response_hook + ) + assert await sess.fetch_token(url) == default_token + assert access_token_response_hook.called is True + + +@pytest.mark.asyncio +async def test_password_grant_type(): + url = "https://provider.test/token" + + async def assert_func(request): + content = await request.body() + content = content.decode() + assert "username=v" in content + assert "scope=profile" in content + assert "grant_type=password" in content + + transport = ASGITransport(AsyncMockDispatch(default_token, assert_func=assert_func)) + async with AsyncOAuth2Client("foo", scope="profile", transport=transport) as sess: + token = await sess.fetch_token(url, username="v", password="v") + assert token == default_token + + token = await sess.fetch_token( + url, username="v", password="v", grant_type="password" + ) + assert token == default_token + + +@pytest.mark.asyncio +async def test_client_credentials_type(): + url = "https://provider.test/token" + + async def assert_func(request): + content = await request.body() + content = content.decode() + assert "scope=profile" in content + assert "grant_type=client_credentials" in content + + transport = ASGITransport(AsyncMockDispatch(default_token, assert_func=assert_func)) + async with AsyncOAuth2Client("foo", scope="profile", transport=transport) as sess: + token = await sess.fetch_token(url) + assert token == default_token + + token = await sess.fetch_token(url, grant_type="client_credentials") + assert token == default_token + + +@pytest.mark.asyncio +async def test_cleans_previous_token_before_fetching_new_one(): + now = int(time.time()) + new_token = deepcopy(default_token) + past = now - 7200 + default_token["expires_at"] = past + new_token["expires_at"] = now + 3600 + url = "https://provider.test/token" + + transport = ASGITransport(AsyncMockDispatch(new_token)) + with mock.patch("time.time", lambda: now): + async with AsyncOAuth2Client( + "foo", token=default_token, transport=transport + ) as sess: + assert await sess.fetch_token(url) == new_token + + +def test_token_status(): + token = dict(access_token="a", token_type="bearer", expires_at=100) + sess = AsyncOAuth2Client("foo", token=token) + assert sess.token.is_expired() is True + + +@pytest.mark.asyncio +async def test_auto_refresh_token(): + async def _update_token(token, refresh_token=None, access_token=None): + assert refresh_token == "b" + assert token == default_token + + update_token = mock.Mock(side_effect=_update_token) + + old_token = dict( + access_token="a", refresh_token="b", token_type="bearer", expires_at=100 + ) + + transport = ASGITransport(AsyncMockDispatch(default_token)) + async with AsyncOAuth2Client( + "foo", + token=old_token, + token_endpoint="https://provider.test/token", + update_token=update_token, + transport=transport, + ) as sess: + await sess.get("https://resource.test/user") + assert update_token.called is True + + old_token = dict(access_token="a", token_type="bearer", expires_at=100) + async with AsyncOAuth2Client( + "foo", + token=old_token, + token_endpoint="https://provider.test/token", + update_token=update_token, + transport=transport, + ) as sess: + with pytest.raises(OAuthError): + await sess.get("https://resource.test/user") + + +@pytest.mark.asyncio +async def test_auto_refresh_token2(): + async def _update_token(token, refresh_token=None, access_token=None): + assert access_token == "a" + assert token == default_token + + update_token = mock.Mock(side_effect=_update_token) + + old_token = dict(access_token="a", token_type="bearer", expires_at=100) + + transport = ASGITransport(AsyncMockDispatch(default_token)) + + async with AsyncOAuth2Client( + "foo", + token=old_token, + token_endpoint="https://provider.test/token", + grant_type="client_credentials", + transport=transport, + ) as client: + await client.get("https://resource.test/user") + assert update_token.called is False + + async with AsyncOAuth2Client( + "foo", + token=old_token, + token_endpoint="https://provider.test/token", + update_token=update_token, + grant_type="client_credentials", + transport=transport, + ) as client: + await client.get("https://resource.test/user") + assert update_token.called is True + + +@pytest.mark.asyncio +async def test_auto_refresh_token3(): + async def _update_token(token, refresh_token=None, access_token=None): + assert access_token == "a" + assert token == default_token + + update_token = mock.Mock(side_effect=_update_token) + + old_token = dict(access_token="a", token_type="bearer", expires_at=100) + + transport = ASGITransport(AsyncMockDispatch(default_token)) + + async with AsyncOAuth2Client( + "foo", + token=old_token, + token_endpoint="https://provider.test/token", + update_token=update_token, + grant_type="client_credentials", + transport=transport, + ) as client: + await client.post("https://resource.test/user", json={"foo": "bar"}) + assert update_token.called is True + + +@pytest.mark.asyncio +async def test_auto_refresh_token4(): + async def _update_token(token, refresh_token=None, access_token=None): + # This test only makes sense if the expired token is refreshed + token["expires_at"] = int(time.time()) + 3600 + # artificial sleep to force other coroutines to wake + await asyncio.sleep(0.1) + + update_token = mock.Mock(side_effect=_update_token) + + old_token = dict(access_token="old", token_type="bearer", expires_at=100) + + transport = ASGITransport(AsyncMockDispatch(default_token)) + + async with AsyncOAuth2Client( + "foo", + token=old_token, + token_endpoint="https://provider.test/token", + update_token=update_token, + grant_type="client_credentials", + transport=transport, + ) as client: + coroutines = [client.get("https://resource.test/user") for x in range(10)] + await asyncio.gather(*coroutines) + update_token.assert_called_once() + + +@pytest.mark.asyncio +async def test_revoke_token(): + answer = {"status": "ok"} + transport = ASGITransport(AsyncMockDispatch(answer)) + + async with AsyncOAuth2Client("a", transport=transport) as sess: + resp = await sess.revoke_token("https://provider.test/token", "hi") + assert resp.json() == answer + + resp = await sess.revoke_token( + "https://provider.test/token", "hi", token_type_hint="access_token" + ) + assert resp.json() == answer + + +@pytest.mark.asyncio +async def test_request_without_token(): + transport = ASGITransport(AsyncMockDispatch()) + async with AsyncOAuth2Client("a", transport=transport) as client: + with pytest.raises(OAuthError): + await client.get("https://provider.test/token") diff --git a/tests/clients/test_httpx/test_oauth1_client.py b/tests/clients/test_httpx/test_oauth1_client.py new file mode 100644 index 000000000..bd9b8fcbf --- /dev/null +++ b/tests/clients/test_httpx/test_oauth1_client.py @@ -0,0 +1,167 @@ +import pytest +from httpx import WSGITransport + +from authlib.integrations.httpx_client import SIGNATURE_TYPE_BODY +from authlib.integrations.httpx_client import SIGNATURE_TYPE_QUERY +from authlib.integrations.httpx_client import OAuth1Client +from authlib.integrations.httpx_client import OAuthError + +from ..wsgi_helper import MockDispatch + +oauth_url = "https://provider.test/oauth" + + +def test_fetch_request_token_via_header(): + request_token = {"oauth_token": "1", "oauth_token_secret": "2"} + + def assert_func(request): + auth_header = request.headers.get("authorization") + assert 'oauth_consumer_key="id"' in auth_header + assert "oauth_signature=" in auth_header + + transport = WSGITransport(MockDispatch(request_token, assert_func=assert_func)) + with OAuth1Client("id", "secret", transport=transport) as client: + response = client.fetch_request_token(oauth_url) + + assert response == request_token + + +def test_fetch_request_token_via_body(): + request_token = {"oauth_token": "1", "oauth_token_secret": "2"} + + def assert_func(request): + auth_header = request.headers.get("authorization") + assert auth_header is None + + content = request.form + assert content.get("oauth_consumer_key") == "id" + assert "oauth_signature" in content + + transport = WSGITransport(MockDispatch(request_token, assert_func=assert_func)) + + with OAuth1Client( + "id", + "secret", + signature_type=SIGNATURE_TYPE_BODY, + transport=transport, + ) as client: + response = client.fetch_request_token(oauth_url) + + assert response == request_token + + +def test_fetch_request_token_via_query(): + request_token = {"oauth_token": "1", "oauth_token_secret": "2"} + + def assert_func(request): + auth_header = request.headers.get("authorization") + assert auth_header is None + + url = str(request.url) + assert "oauth_consumer_key=id" in url + assert "&oauth_signature=" in url + + transport = WSGITransport(MockDispatch(request_token, assert_func=assert_func)) + + with OAuth1Client( + "id", + "secret", + signature_type=SIGNATURE_TYPE_QUERY, + transport=transport, + ) as client: + response = client.fetch_request_token(oauth_url) + + assert response == request_token + + +def test_fetch_access_token(): + request_token = {"oauth_token": "1", "oauth_token_secret": "2"} + + def assert_func(request): + auth_header = request.headers.get("authorization") + assert 'oauth_verifier="d"' in auth_header + assert 'oauth_token="foo"' in auth_header + assert 'oauth_consumer_key="id"' in auth_header + assert "oauth_signature=" in auth_header + + transport = WSGITransport(MockDispatch(request_token, assert_func=assert_func)) + with OAuth1Client( + "id", + "secret", + token="foo", + token_secret="bar", + transport=transport, + ) as client: + with pytest.raises(OAuthError): + client.fetch_access_token(oauth_url) + + response = client.fetch_access_token(oauth_url, verifier="d") + + assert response == request_token + + +def test_get_via_header(): + transport = WSGITransport(MockDispatch(b"hello")) + with OAuth1Client( + "id", + "secret", + token="foo", + token_secret="bar", + transport=transport, + ) as client: + response = client.get("https://resource.test/") + + assert response.content == b"hello" + request = response.request + auth_header = request.headers.get("authorization") + assert 'oauth_token="foo"' in auth_header + assert 'oauth_consumer_key="id"' in auth_header + assert "oauth_signature=" in auth_header + + +def test_get_via_body(): + def assert_func(request): + content = request.form + assert content.get("oauth_token") == "foo" + assert content.get("oauth_consumer_key") == "id" + assert "oauth_signature" in content + + transport = WSGITransport(MockDispatch(b"hello", assert_func=assert_func)) + with OAuth1Client( + "id", + "secret", + token="foo", + token_secret="bar", + signature_type=SIGNATURE_TYPE_BODY, + transport=transport, + ) as client: + response = client.post("https://resource.test/") + + assert response.content == b"hello" + + request = response.request + auth_header = request.headers.get("authorization") + assert auth_header is None + + +def test_get_via_query(): + transport = WSGITransport(MockDispatch(b"hello")) + with OAuth1Client( + "id", + "secret", + token="foo", + token_secret="bar", + signature_type=SIGNATURE_TYPE_QUERY, + transport=transport, + ) as client: + response = client.get("https://resource.test/") + + assert response.content == b"hello" + request = response.request + auth_header = request.headers.get("authorization") + assert auth_header is None + + url = str(request.url) + assert "oauth_token=foo" in url + assert "oauth_consumer_key=id" in url + assert "oauth_signature=" in url diff --git a/tests/clients/test_httpx/test_oauth2_client.py b/tests/clients/test_httpx/test_oauth2_client.py new file mode 100644 index 000000000..a1f6b6049 --- /dev/null +++ b/tests/clients/test_httpx/test_oauth2_client.py @@ -0,0 +1,371 @@ +import time +from copy import deepcopy +from unittest import mock + +import pytest +from httpx import WSGITransport + +from authlib.common.security import generate_token +from authlib.common.urls import url_encode +from authlib.integrations.httpx_client import OAuth2Client +from authlib.integrations.httpx_client import OAuthError + +from ..wsgi_helper import MockDispatch + +default_token = { + "token_type": "Bearer", + "access_token": "a", + "refresh_token": "b", + "expires_in": "3600", + "expires_at": int(time.time()) + 3600, +} + + +def assert_token_in_header(request): + token = "Bearer " + default_token["access_token"] + auth_header = request.headers.get("authorization") + assert auth_header == token + + +def assert_token_in_body(request): + content = request.data + content = content.decode() + assert content == "access_token={}".format(default_token["access_token"]) + + +def assert_token_in_uri(request): + assert default_token["access_token"] in str(request.url) + + +@pytest.mark.parametrize( + "assert_func, token_placement", + [ + (assert_token_in_header, "header"), + (assert_token_in_body, "body"), + (assert_token_in_uri, "uri"), + ], +) +def test_add_token_get_request(assert_func, token_placement): + transport = WSGITransport(MockDispatch({"a": "a"}, assert_func=assert_func)) + with OAuth2Client( + "foo", token=default_token, token_placement=token_placement, transport=transport + ) as client: + resp = client.get("https://provider.test") + + data = resp.json() + assert data["a"] == "a" + + +@pytest.mark.parametrize( + "assert_func, token_placement", + [ + (assert_token_in_header, "header"), + (assert_token_in_body, "body"), + (assert_token_in_uri, "uri"), + ], +) +def test_add_token_to_streaming_request(assert_func, token_placement): + transport = WSGITransport(MockDispatch({"a": "a"}, assert_func=assert_func)) + with OAuth2Client( + "foo", token=default_token, token_placement=token_placement, transport=transport + ) as client: + with client.stream("GET", "https://provider.test") as stream: + stream.read() + data = stream.json() + assert data["a"] == "a" + + +def test_create_authorization_url(): + url = "https://provider.test/authorize?foo=bar" + + sess = OAuth2Client(client_id="foo") + auth_url, state = sess.create_authorization_url(url) + assert state in auth_url + assert "client_id=foo" in auth_url + assert "response_type=code" in auth_url + + sess = OAuth2Client(client_id="foo", prompt="none") + auth_url, state = sess.create_authorization_url( + url, state="foo", redirect_uri="https://provider.test", scope="profile" + ) + assert state == "foo" + assert "provider.test" in auth_url + assert "profile" in auth_url + assert "prompt=none" in auth_url + + +def test_code_challenge(): + sess = OAuth2Client("foo", code_challenge_method="S256") + + url = "https://provider.test/authorize" + auth_url, _ = sess.create_authorization_url(url, code_verifier=generate_token(48)) + assert "code_challenge=" in auth_url + assert "code_challenge_method=S256" in auth_url + + +def test_token_from_fragment(): + sess = OAuth2Client("foo") + response_url = "https://provider.test/callback#" + url_encode(default_token.items()) + assert sess.token_from_fragment(response_url) == default_token + token = sess.fetch_token(authorization_response=response_url) + assert token == default_token + + +def test_fetch_token_post(): + url = "https://provider.test/token" + + def assert_func(request): + content = request.form + assert content.get("code") == "v" + assert content.get("client_id") == "foo" + assert content.get("grant_type") == "authorization_code" + + transport = WSGITransport(MockDispatch(default_token, assert_func=assert_func)) + with OAuth2Client("foo", transport=transport) as client: + token = client.fetch_token( + url, authorization_response="https://provider.test/?code=v" + ) + assert token == default_token + + with OAuth2Client( + "foo", token_endpoint_auth_method="none", transport=transport + ) as client: + token = client.fetch_token(url, code="v") + assert token == default_token + + transport = WSGITransport(MockDispatch({"error": "invalid_request"})) + with OAuth2Client("foo", transport=transport) as client: + with pytest.raises(OAuthError): + client.fetch_token(url) + + +def test_fetch_token_get(): + url = "https://provider.test/token" + + def assert_func(request): + url = str(request.url) + assert "code=v" in url + assert "client_id=" in url + assert "grant_type=authorization_code" in url + + transport = WSGITransport(MockDispatch(default_token, assert_func=assert_func)) + with OAuth2Client("foo", transport=transport) as client: + authorization_response = "https://provider.test/?code=v" + token = client.fetch_token( + url, authorization_response=authorization_response, method="GET" + ) + assert token == default_token + + with OAuth2Client( + "foo", token_endpoint_auth_method="none", transport=transport + ) as client: + token = client.fetch_token(url, code="v", method="GET") + assert token == default_token + + token = client.fetch_token(url + "?q=a", code="v", method="GET") + assert token == default_token + + +def test_token_auth_method_client_secret_post(): + url = "https://provider.test/token" + + def assert_func(request): + content = request.form + assert content.get("code") == "v" + assert content.get("client_id") == "foo" + assert content.get("client_secret") == "bar" + assert content.get("grant_type") == "authorization_code" + + transport = WSGITransport(MockDispatch(default_token, assert_func=assert_func)) + with OAuth2Client( + "foo", + "bar", + token_endpoint_auth_method="client_secret_post", + transport=transport, + ) as client: + token = client.fetch_token(url, code="v") + + assert token == default_token + + +def test_access_token_response_hook(): + url = "https://provider.test/token" + + def _access_token_response_hook(resp): + assert resp.json() == default_token + return resp + + access_token_response_hook = mock.Mock(side_effect=_access_token_response_hook) + transport = WSGITransport(MockDispatch(default_token)) + with OAuth2Client("foo", token=default_token, transport=transport) as sess: + sess.register_compliance_hook( + "access_token_response", access_token_response_hook + ) + assert sess.fetch_token(url) == default_token + assert access_token_response_hook.called is True + + +def test_password_grant_type(): + url = "https://provider.test/token" + + def assert_func(request): + content = request.form + assert content.get("username") == "v" + assert content.get("scope") == "profile" + assert content.get("grant_type") == "password" + + transport = WSGITransport(MockDispatch(default_token, assert_func=assert_func)) + with OAuth2Client("foo", scope="profile", transport=transport) as sess: + token = sess.fetch_token(url, username="v", password="v") + assert token == default_token + + token = sess.fetch_token(url, username="v", password="v", grant_type="password") + assert token == default_token + + +def test_client_credentials_type(): + url = "https://provider.test/token" + + def assert_func(request): + content = request.form + assert content.get("scope") == "profile" + assert content.get("grant_type") == "client_credentials" + + transport = WSGITransport(MockDispatch(default_token, assert_func=assert_func)) + with OAuth2Client("foo", scope="profile", transport=transport) as sess: + token = sess.fetch_token(url) + assert token == default_token + + token = sess.fetch_token(url, grant_type="client_credentials") + assert token == default_token + + +def test_cleans_previous_token_before_fetching_new_one(): + now = int(time.time()) + new_token = deepcopy(default_token) + past = now - 7200 + default_token["expires_at"] = past + new_token["expires_at"] = now + 3600 + url = "https://provider.test/token" + + transport = WSGITransport(MockDispatch(new_token)) + with mock.patch("time.time", lambda: now): + with OAuth2Client("foo", token=default_token, transport=transport) as sess: + assert sess.fetch_token(url) == new_token + + +def test_token_status(): + token = dict(access_token="a", token_type="bearer", expires_at=100) + sess = OAuth2Client("foo", token=token) + assert sess.token.is_expired() is True + + +def test_auto_refresh_token(): + def _update_token(token, refresh_token=None, access_token=None): + assert refresh_token == "b" + assert token == default_token + + update_token = mock.Mock(side_effect=_update_token) + + old_token = dict( + access_token="a", refresh_token="b", token_type="bearer", expires_at=100 + ) + + transport = WSGITransport(MockDispatch(default_token)) + with OAuth2Client( + "foo", + token=old_token, + token_endpoint="https://provider.test/token", + update_token=update_token, + transport=transport, + ) as sess: + sess.get("https://resource.test/user") + assert update_token.called is True + + old_token = dict(access_token="a", token_type="bearer", expires_at=100) + with OAuth2Client( + "foo", + token=old_token, + token_endpoint="https://provider.test/token", + update_token=update_token, + transport=transport, + ) as sess: + with pytest.raises(OAuthError): + sess.get("https://resource.test/user") + + +def test_auto_refresh_token2(): + def _update_token(token, refresh_token=None, access_token=None): + assert access_token == "a" + assert token == default_token + + update_token = mock.Mock(side_effect=_update_token) + + old_token = dict(access_token="a", token_type="bearer", expires_at=100) + + transport = WSGITransport(MockDispatch(default_token)) + + with OAuth2Client( + "foo", + token=old_token, + token_endpoint="https://provider.test/token", + grant_type="client_credentials", + transport=transport, + ) as client: + client.get("https://resource.test/user") + assert update_token.called is False + + with OAuth2Client( + "foo", + token=old_token, + token_endpoint="https://provider.test/token", + update_token=update_token, + grant_type="client_credentials", + transport=transport, + ) as client: + client.get("https://resource.test/user") + assert update_token.called is True + + +def test_auto_refresh_token3(): + def _update_token(token, refresh_token=None, access_token=None): + assert access_token == "a" + assert token == default_token + + update_token = mock.Mock(side_effect=_update_token) + + old_token = dict(access_token="a", token_type="bearer", expires_at=100) + + transport = WSGITransport(MockDispatch(default_token)) + + with OAuth2Client( + "foo", + token=old_token, + token_endpoint="https://provider.test/token", + update_token=update_token, + grant_type="client_credentials", + transport=transport, + ) as client: + client.post("https://resource.test/user", json={"foo": "bar"}) + assert update_token.called is True + + +def test_revoke_token(): + answer = {"status": "ok"} + transport = WSGITransport(MockDispatch(answer)) + + with OAuth2Client("a", transport=transport) as sess: + resp = sess.revoke_token("https://provider.test/token", "hi") + assert resp.json() == answer + + resp = sess.revoke_token( + "https://provider.test/token", "hi", token_type_hint="access_token" + ) + assert resp.json() == answer + + +def test_request_without_token(): + transport = WSGITransport(MockDispatch()) + with OAuth2Client("a", transport=transport) as client: + with pytest.raises(OAuthError): + client.get("https://provider.test/token") diff --git a/tests/py3/__init__.py b/tests/clients/test_requests/__init__.py similarity index 100% rename from tests/py3/__init__.py rename to tests/clients/test_requests/__init__.py diff --git a/tests/clients/test_requests/test_assertion_session.py b/tests/clients/test_requests/test_assertion_session.py new file mode 100644 index 000000000..6d93e564f --- /dev/null +++ b/tests/clients/test_requests/test_assertion_session.py @@ -0,0 +1,70 @@ +import time +from unittest import mock + +import pytest + +from authlib.integrations.requests_client import AssertionSession + + +@pytest.fixture +def token(): + return { + "token_type": "Bearer", + "access_token": "a", + "refresh_token": "b", + "expires_in": "3600", + "expires_at": int(time.time()) + 3600, + } + + +def test_refresh_token(token): + def verifier(r, **kwargs): + resp = mock.MagicMock() + resp.status_code = 200 + if r.url == "https://provider.test/token": + assert "assertion=" in r.body + resp.json = lambda: token + return resp + + sess = AssertionSession( + "https://provider.test/token", + issuer="foo", + subject="foo", + audience="foo", + alg="HS256", + key="secret", + ) + sess.send = verifier + sess.get("https://provider.test") + + # trigger more case + now = int(time.time()) + sess = AssertionSession( + "https://provider.test/token", + issuer="foo", + subject=None, + audience="foo", + issued_at=now, + expires_at=now + 3600, + header={"alg": "HS256"}, + key="secret", + scope="email", + claims={"test_mode": "true"}, + ) + sess.send = verifier + sess.get("https://provider.test") + # trigger for branch test case + sess.get("https://provider.test") + + +def test_without_alg(): + sess = AssertionSession( + "https://provider.test/token", + grant_type=AssertionSession.JWT_BEARER_GRANT_TYPE, + issuer="foo", + subject="foo", + audience="foo", + key="secret", + ) + with pytest.raises(ValueError): + sess.get("https://provider.test") diff --git a/tests/clients/test_requests/test_oauth1_session.py b/tests/clients/test_requests/test_oauth1_session.py new file mode 100644 index 000000000..968ad6552 --- /dev/null +++ b/tests/clients/test_requests/test_oauth1_session.py @@ -0,0 +1,286 @@ +from io import StringIO +from unittest import mock + +import pytest +import requests + +from authlib.common.encoding import to_unicode +from authlib.integrations.requests_client import OAuth1Session +from authlib.integrations.requests_client import OAuthError +from authlib.oauth1 import SIGNATURE_PLAINTEXT +from authlib.oauth1 import SIGNATURE_RSA_SHA1 +from authlib.oauth1 import SIGNATURE_TYPE_BODY +from authlib.oauth1 import SIGNATURE_TYPE_QUERY +from authlib.oauth1.rfc5849.util import escape + +from ..util import mock_text_response +from ..util import read_key_file + +TEST_RSA_OAUTH_SIGNATURE = ( + "Pko%2BFb4T1XGDE5DlLjuEMthVXjczqGi8qyfQ%2FSE405bBLEywint1tYNGN1me8h" + "JoXZMqyXy%2F%2FAzJ0ViRYRc7rDTaTYyjB%2Fct%2FFt8f4lb3e9LfGhgkwih%2FsH2w%3D%3D" +) + + +def test_no_client_id(): + with pytest.raises(ValueError): + OAuth1Session(None) + + +def test_signature_types(): + def verify_signature(getter): + def fake_send(r, **kwargs): + signature = to_unicode(getter(r)) + assert "oauth_signature" in signature + resp = mock.MagicMock(spec=requests.Response) + resp.cookies = [] + return resp + + return fake_send + + header = OAuth1Session("foo") + header.send = verify_signature(lambda r: r.headers["Authorization"]) + header.post("https://provider.test") + + query = OAuth1Session("foo", signature_type=SIGNATURE_TYPE_QUERY) + query.send = verify_signature(lambda r: r.url) + query.post("https://provider.test") + + body = OAuth1Session("foo", signature_type=SIGNATURE_TYPE_BODY) + headers = {"Content-Type": "application/x-www-form-urlencoded"} + body.send = verify_signature(lambda r: r.body) + body.post("https://provider.test", headers=headers, data="") + + +@mock.patch("authlib.oauth1.rfc5849.client_auth.generate_timestamp") +@mock.patch("authlib.oauth1.rfc5849.client_auth.generate_nonce") +def test_signature_methods(generate_nonce, generate_timestamp): + generate_nonce.return_value = "abc" + generate_timestamp.return_value = "123" + + signature = ", ".join( + [ + 'OAuth oauth_nonce="abc"', + 'oauth_timestamp="123"', + 'oauth_version="1.0"', + 'oauth_signature_method="HMAC-SHA1"', + 'oauth_consumer_key="foo"', + 'oauth_signature="GuqiSr5%2FHajrrmc%2FFprUV4cCGbw%3D"', + ] + ) + auth = OAuth1Session("foo") + auth.send = verify_signature(signature) + auth.post("https://provider.test") + + signature = ( + "OAuth " + 'oauth_nonce="abc", oauth_timestamp="123", oauth_version="1.0", ' + 'oauth_signature_method="PLAINTEXT", oauth_consumer_key="foo", ' + 'oauth_signature="%26"' + ) + auth = OAuth1Session("foo", signature_method=SIGNATURE_PLAINTEXT) + auth.send = verify_signature(signature) + auth.post("https://provider.test") + + signature = ( + "OAuth " + 'oauth_nonce="abc", oauth_timestamp="123", oauth_version="1.0", ' + 'oauth_signature_method="RSA-SHA1", oauth_consumer_key="foo", ' + f'oauth_signature="{TEST_RSA_OAUTH_SIGNATURE}"' + ) + + rsa_key = read_key_file("rsa_private.pem") + auth = OAuth1Session("foo", signature_method=SIGNATURE_RSA_SHA1, rsa_key=rsa_key) + auth.send = verify_signature(signature) + auth.post("https://provider.test") + + +@mock.patch("authlib.oauth1.rfc5849.client_auth.generate_timestamp") +@mock.patch("authlib.oauth1.rfc5849.client_auth.generate_nonce") +def test_binary_upload(generate_nonce, generate_timestamp): + generate_nonce.return_value = "abc" + generate_timestamp.return_value = "123" + fake_xml = StringIO("hello world") + headers = {"Content-Type": "application/xml"} + + def fake_send(r, **kwargs): + auth_header = r.headers["Authorization"] + assert "oauth_body_hash" in auth_header + + auth = OAuth1Session("foo", force_include_body=True) + auth.send = fake_send + auth.post("https://provider.test", headers=headers, files=[("fake", fake_xml)]) + + +@mock.patch("authlib.oauth1.rfc5849.client_auth.generate_timestamp") +@mock.patch("authlib.oauth1.rfc5849.client_auth.generate_nonce") +def test_nonascii(generate_nonce, generate_timestamp): + generate_nonce.return_value = "abc" + generate_timestamp.return_value = "123" + signature = ( + 'OAuth oauth_nonce="abc", oauth_timestamp="123", oauth_version="1.0", ' + 'oauth_signature_method="HMAC-SHA1", oauth_consumer_key="foo", ' + 'oauth_signature="USkqQvV76SCKBewYI9cut6FfYcI%3D"' + ) + auth = OAuth1Session("foo") + auth.send = verify_signature(signature) + auth.post("https://provider.test?cjk=%E5%95%A6%E5%95%A6") + + +def test_redirect_uri(): + sess = OAuth1Session("foo") + assert sess.redirect_uri is None + url = "https://provider.test" + sess.redirect_uri = url + assert sess.redirect_uri == url + + +def test_set_token(): + sess = OAuth1Session("foo") + try: + sess.token = {} + except OAuthError as exc: + assert exc.error == "missing_token" + + sess.token = {"oauth_token": "a", "oauth_token_secret": "b"} + assert sess.token["oauth_verifier"] is None + sess.token = {"oauth_token": "a", "oauth_verifier": "c"} + assert sess.token["oauth_token_secret"] == "b" + assert sess.token["oauth_verifier"] == "c" + + sess.token = None + assert sess.token["oauth_token"] is None + assert sess.token["oauth_token_secret"] is None + assert sess.token["oauth_verifier"] is None + + +def test_create_authorization_url(): + auth = OAuth1Session("foo") + url = "https://provider.test/authorize" + token = "asluif023sf" + auth_url = auth.create_authorization_url(url, request_token=token) + assert auth_url == url + "?oauth_token=" + token + redirect_uri = "https://client.test/callback" + auth = OAuth1Session("foo", redirect_uri=redirect_uri) + auth_url = auth.create_authorization_url(url, request_token=token) + assert escape(redirect_uri) in auth_url + + +def test_parse_response_url(): + url = "https://provider.test/callback?oauth_token=foo&oauth_verifier=bar" + auth = OAuth1Session("foo") + resp = auth.parse_authorization_response(url) + assert resp["oauth_token"] == "foo" + assert resp["oauth_verifier"] == "bar" + for k, v in resp.items(): + assert isinstance(k, str) + assert isinstance(v, str) + + +def test_fetch_request_token(): + auth = OAuth1Session("foo", realm="A") + auth.send = mock_text_response("oauth_token=foo") + resp = auth.fetch_request_token("https://provider.test/token") + assert resp["oauth_token"] == "foo" + for k, v in resp.items(): + assert isinstance(k, str) + assert isinstance(v, str) + + resp = auth.fetch_request_token("https://provider.test/token") + assert resp["oauth_token"] == "foo" + + +def test_fetch_request_token_with_optional_arguments(): + auth = OAuth1Session("foo") + auth.send = mock_text_response("oauth_token=foo") + resp = auth.fetch_request_token( + "https://provider.test/token", verify=False, stream=True + ) + assert resp["oauth_token"] == "foo" + for k, v in resp.items(): + assert isinstance(k, str) + assert isinstance(v, str) + + +def test_fetch_access_token(): + auth = OAuth1Session("foo", verifier="bar") + auth.send = mock_text_response("oauth_token=foo") + resp = auth.fetch_access_token("https://provider.test/token") + assert resp["oauth_token"] == "foo" + for k, v in resp.items(): + assert isinstance(k, str) + assert isinstance(v, str) + + auth = OAuth1Session("foo", verifier="bar") + auth.send = mock_text_response('{"oauth_token":"foo"}') + resp = auth.fetch_access_token("https://provider.test/token") + assert resp["oauth_token"] == "foo" + + auth = OAuth1Session("foo") + auth.send = mock_text_response("oauth_token=foo") + resp = auth.fetch_access_token("https://provider.test/token", verifier="bar") + assert resp["oauth_token"] == "foo" + + +def test_fetch_access_token_with_optional_arguments(): + auth = OAuth1Session("foo", verifier="bar") + auth.send = mock_text_response("oauth_token=foo") + resp = auth.fetch_access_token( + "https://provider.test/token", verify=False, stream=True + ) + assert resp["oauth_token"] == "foo" + for k, v in resp.items(): + assert isinstance(k, str) + assert isinstance(v, str) + + +def _test_fetch_access_token_raises_error(session): + """Assert that an error is being raised whenever there's no verifier + passed in to the client. + """ + session.send = mock_text_response("oauth_token=foo") + with pytest.raises(OAuthError, match="missing_verifier"): + session.fetch_access_token("https://provider.test/token") + + +def test_fetch_token_invalid_response(): + auth = OAuth1Session("foo") + auth.send = mock_text_response("not valid urlencoded response!") + with pytest.raises(ValueError): + auth.fetch_request_token("https://provider.test/token") + + for code in (400, 401, 403): + auth.send = mock_text_response("valid=response", code) + with pytest.raises(OAuthError, match="fetch_token_denied"): + auth.fetch_request_token("https://provider.test/token") + + +def test_fetch_access_token_missing_verifier(): + _test_fetch_access_token_raises_error(OAuth1Session("foo")) + + +def test_fetch_access_token_has_verifier_is_none(): + session = OAuth1Session("foo") + session.auth.verifier = None + _test_fetch_access_token_raises_error(session) + + +def verify_signature(signature): + def fake_send(r, **kwargs): + auth_header = to_unicode(r.headers["Authorization"]) + # RSA signatures are non-deterministic, so we only check the prefix for RSA-SHA1 + if 'oauth_signature_method="RSA-SHA1"' in signature: + signature_prefix = ( + signature.split('oauth_signature="')[0] + 'oauth_signature="' + ) + auth_prefix = ( + auth_header.split('oauth_signature="')[0] + 'oauth_signature="' + ) + assert auth_prefix == signature_prefix + else: + assert auth_header == signature + resp = mock.MagicMock(spec=requests.Response) + resp.cookies = [] + return resp + + return fake_send diff --git a/tests/clients/test_requests/test_oauth2_session.py b/tests/clients/test_requests/test_oauth2_session.py new file mode 100644 index 000000000..a6e0feb19 --- /dev/null +++ b/tests/clients/test_requests/test_oauth2_session.py @@ -0,0 +1,622 @@ +import time +from copy import deepcopy +from unittest import mock + +import pytest +from joserfc.jwk import OctKey +from joserfc.jwk import RSAKey + +from authlib.common.security import generate_token +from authlib.common.urls import add_params_to_uri +from authlib.common.urls import url_encode +from authlib.integrations.requests_client import OAuth2Session +from authlib.integrations.requests_client import OAuthError +from authlib.oauth2.rfc6749 import MismatchingStateException +from authlib.oauth2.rfc7523 import ClientSecretJWT +from authlib.oauth2.rfc7523 import PrivateKeyJWT + +from ..util import read_key_file + + +def mock_json_response(payload): + def fake_send(r, **kwargs): + resp = mock.MagicMock() + resp.status_code = 200 + resp.json = lambda: payload + return resp + + return fake_send + + +def mock_assertion_response(token, session): + def fake_send(r, **kwargs): + assert "client_assertion=" in r.body + assert "client_assertion_type=" in r.body + resp = mock.MagicMock() + resp.status_code = 200 + resp.json = lambda: token + return resp + + session.send = fake_send + + +@pytest.fixture +def token(): + return { + "token_type": "Bearer", + "access_token": "a", + "refresh_token": "b", + "expires_in": "3600", + "expires_at": int(time.time()) + 3600, + } + + +def test_invalid_token_type(token): + token = { + "token_type": "invalid", + "access_token": "a", + "refresh_token": "b", + "expires_in": "3600", + "expires_at": int(time.time()) + 3600, + } + with OAuth2Session("foo", token=token) as sess: + with pytest.raises(OAuthError): + sess.get("https://provider.test") + + +def test_add_token_to_header(token): + expected_header = "Bearer " + token["access_token"] + + def verifier(r, **kwargs): + auth_header = r.headers.get("Authorization", None) + assert auth_header == expected_header + resp = mock.MagicMock() + return resp + + sess = OAuth2Session(client_id="foo", token=token) + sess.send = verifier + sess.get("https://provider.test") + + +def test_add_token_to_body(token): + def verifier(r, **kwargs): + assert token["access_token"] in r.body + resp = mock.MagicMock() + return resp + + sess = OAuth2Session(client_id="foo", token=token, token_placement="body") + sess.send = verifier + sess.post("https://provider.test") + + +def test_add_token_to_uri(token): + def verifier(r, **kwargs): + assert token["access_token"] in r.url + resp = mock.MagicMock() + return resp + + sess = OAuth2Session(client_id="foo", token=token, token_placement="uri") + sess.send = verifier + sess.get("https://provider.test") + + +def test_create_authorization_url(): + url = "https://provider.test/authorize?foo=bar" + + sess = OAuth2Session(client_id="foo") + auth_url, state = sess.create_authorization_url(url) + assert state in auth_url + assert "foo" in auth_url + assert "response_type=code" in auth_url + + sess = OAuth2Session(client_id="foo", prompt="none") + auth_url, state = sess.create_authorization_url( + url, state="foo", redirect_uri="https://provider.test", scope="profile" + ) + assert state == "foo" + assert "provider.test" in auth_url + assert "profile" in auth_url + assert "prompt=none" in auth_url + + +def test_code_challenge(): + sess = OAuth2Session(client_id="foo", code_challenge_method="S256") + + url = "https://provider.test/authorize" + auth_url, _ = sess.create_authorization_url(url, code_verifier=generate_token(48)) + assert "code_challenge" in auth_url + assert "code_challenge_method=S256" in auth_url + + +def test_token_from_fragment(token): + sess = OAuth2Session("foo") + response_url = "https://provider.test/callback#" + url_encode(token.items()) + assert sess.token_from_fragment(response_url) == token + token = sess.fetch_token(authorization_response=response_url) + assert token == token + + +def test_fetch_token_post(token): + url = "https://provider.test/token" + + def fake_send(r, **kwargs): + assert "code=v" in r.body + assert "client_id=" in r.body + assert "grant_type=authorization_code" in r.body + resp = mock.MagicMock() + resp.status_code = 200 + resp.json = lambda: token + return resp + + sess = OAuth2Session(client_id="foo") + sess.send = fake_send + assert ( + sess.fetch_token(url, authorization_response="https://provider.test/?code=v") + == token + ) + + sess = OAuth2Session( + client_id="foo", + token_endpoint_auth_method="none", + ) + sess.send = fake_send + token = sess.fetch_token(url, code="v") + assert token == token + + error = {"error": "invalid_request"} + sess = OAuth2Session(client_id="foo", token=token) + sess.send = mock_json_response(error) + with pytest.raises(OAuthError): + sess.fetch_access_token(url) + + +def test_fetch_token_get(token): + url = "https://provider.test/token" + + def fake_send(r, **kwargs): + assert "code=v" in r.url + assert "grant_type=authorization_code" in r.url + resp = mock.MagicMock() + resp.status_code = 200 + resp.json = lambda: token + return resp + + sess = OAuth2Session(client_id="foo") + sess.send = fake_send + token = sess.fetch_token( + url, authorization_response="https://provider.test/?code=v", method="GET" + ) + assert token == token + + sess = OAuth2Session( + client_id="foo", + token_endpoint_auth_method="none", + ) + sess.send = fake_send + token = sess.fetch_token(url, code="v", method="GET") + assert token == token + + token = sess.fetch_token(url + "?q=a", code="v", method="GET") + assert token == token + + +def test_token_auth_method_client_secret_post(token): + url = "https://provider.test/token" + + def fake_send(r, **kwargs): + assert "code=v" in r.body + assert "client_id=" in r.body + assert "client_secret=bar" in r.body + assert "grant_type=authorization_code" in r.body + resp = mock.MagicMock() + resp.status_code = 200 + resp.json = lambda: token + return resp + + sess = OAuth2Session( + client_id="foo", + client_secret="bar", + token_endpoint_auth_method="client_secret_post", + ) + sess.send = fake_send + token = sess.fetch_token(url, code="v") + assert token == token + + +def test_access_token_response_hook(token): + url = "https://provider.test/token" + + def access_token_response_hook(resp): + assert resp.json() == token + return resp + + sess = OAuth2Session(client_id="foo", token=token) + sess.register_compliance_hook("access_token_response", access_token_response_hook) + sess.send = mock_json_response(token) + assert sess.fetch_token(url) == token + + +def test_password_grant_type(token): + url = "https://provider.test/token" + + def fake_send(r, **kwargs): + assert "username=v" in r.body + assert "grant_type=password" in r.body + assert "scope=profile" in r.body + resp = mock.MagicMock() + resp.status_code = 200 + resp.json = lambda: token + return resp + + sess = OAuth2Session(client_id="foo", scope="profile") + sess.send = fake_send + token = sess.fetch_token(url, username="v", password="v") + assert token == token + + +def test_client_credentials_type(token): + url = "https://provider.test/token" + + def fake_send(r, **kwargs): + assert "grant_type=client_credentials" in r.body + assert "scope=profile" in r.body + resp = mock.MagicMock() + resp.status_code = 200 + resp.json = lambda: token + return resp + + sess = OAuth2Session( + client_id="foo", + client_secret="v", + scope="profile", + ) + sess.send = fake_send + token = sess.fetch_token(url) + assert token == token + + +def test_cleans_previous_token_before_fetching_new_one(token): + """Makes sure the previous token is cleaned before fetching a new one. + The reason behind it is that, if the previous token is expired, this + method shouldn't fail with a TokenExpiredError, since it's attempting + to get a new one (which shouldn't be expired). + """ + now = int(time.time()) + new_token = deepcopy(token) + past = now - 7200 + token["expires_at"] = past + new_token["expires_at"] = now + 3600 + url = "https://provider.test/token" + + with mock.patch("time.time", lambda: now): + sess = OAuth2Session(client_id="foo", token=token) + sess.send = mock_json_response(new_token) + assert sess.fetch_token(url) == new_token + + +def test_mis_match_state(token): + sess = OAuth2Session("foo") + with pytest.raises(MismatchingStateException): + sess.fetch_token( + "https://provider.test/token", + authorization_response="https://provider.test/no-state?code=abc", + state="somestate", + ) + + +def test_token_status(): + token = dict(access_token="a", token_type="bearer", expires_at=100) + sess = OAuth2Session("foo", token=token) + + assert sess.token.is_expired + + +def test_token_status2(): + token = dict(access_token="a", token_type="bearer", expires_in=10) + sess = OAuth2Session("foo", token=token, leeway=15) + + assert sess.token.is_expired(sess.leeway) + + +def test_token_status3(): + token = dict(access_token="a", token_type="bearer", expires_in=10) + sess = OAuth2Session("foo", token=token, leeway=5) + + assert not sess.token.is_expired(sess.leeway) + + +def test_expires_in_used_when_expires_at_unparseable(): + """Test that expires_in is used as fallback when expires_at is unparsable.""" + token = dict( + access_token="a", + token_type="bearer", + expires_in=3600, # 1 hour from now + expires_at="2024-01-01T00:00:00Z", # Unparsable - should fall back to expires_in + ) + sess = OAuth2Session("foo", token=token) + + # The token should use expires_in since expires_at is unparsable + # So it should be considered expired with leeway > 3600 + assert sess.token.is_expired(leeway=3700) is True + # And not expired with leeway < 3600 + assert sess.token.is_expired(leeway=0) is False + # expires_at should be calculated from expires_in + assert isinstance(sess.token["expires_at"], int) + + +def test_unparseable_expires_at_returns_none(): + """Test that is_expired returns None when expires_at is unparsable and no expires_in.""" + token = dict( + access_token="a", + token_type="bearer", + expires_at="2024-01-01T00:00:00Z", # Unparsable date string + ) + sess = OAuth2Session("foo", token=token) + + # Should return None since we can't determine expiration + assert sess.token.is_expired() is None + # The unparsable expires_at should be preserved in the token + assert sess.token["expires_at"] == "2024-01-01T00:00:00Z" + # No expires_in should be calculated + assert "expires_in" not in sess.token + + +def test_token_expired(): + token = dict(access_token="a", token_type="bearer", expires_at=100) + sess = OAuth2Session("foo", token=token) + with pytest.raises(OAuthError): + sess.get( + "https://provider.test/token", + ) + + +def test_missing_token(): + sess = OAuth2Session("foo") + with pytest.raises(OAuthError): + sess.get( + "https://provider.test/token", + ) + + +def test_register_compliance_hook(token): + sess = OAuth2Session("foo") + with pytest.raises(ValueError): + sess.register_compliance_hook( + "invalid_hook", + lambda o: o, + ) + + def protected_request(url, headers, data): + assert "Authorization" in headers + return url, headers, data + + sess = OAuth2Session("foo", token=token) + sess.register_compliance_hook( + "protected_request", + protected_request, + ) + sess.send = mock_json_response({"name": "a"}) + sess.get("https://resource.test/user") + + +def test_auto_refresh_token(token): + def _update_token(token_, refresh_token=None, access_token=None): + assert refresh_token == "b" + assert token == token_ + + update_token = mock.Mock(side_effect=_update_token) + old_token = dict( + access_token="a", refresh_token="b", token_type="bearer", expires_at=100 + ) + sess = OAuth2Session( + "foo", + token=old_token, + token_endpoint="https://provider.test/token", + update_token=update_token, + ) + sess.send = mock_json_response(token) + sess.get("https://resource.test/user") + assert update_token.called + + +def test_auto_refresh_token2(token): + def _update_token(token_, refresh_token=None, access_token=None): + assert access_token == "a" + assert token == token_ + + update_token = mock.Mock(side_effect=_update_token) + old_token = dict(access_token="a", token_type="bearer", expires_at=100) + + sess = OAuth2Session( + "foo", + token=old_token, + token_endpoint="https://provider.test/token", + grant_type="client_credentials", + ) + sess.send = mock_json_response(token) + sess.get("https://resource.test/user") + assert not update_token.called + + sess = OAuth2Session( + "foo", + token=old_token, + token_endpoint="https://provider.test/token", + grant_type="client_credentials", + update_token=update_token, + ) + sess.send = mock_json_response(token) + sess.get("https://resource.test/user") + assert update_token.called + + +def test_revoke_token(): + sess = OAuth2Session("a") + answer = {"status": "ok"} + sess.send = mock_json_response(answer) + resp = sess.revoke_token("https://provider.test/token", "hi") + assert resp.json() == answer + resp = sess.revoke_token( + "https://provider.test/token", "hi", token_type_hint="access_token" + ) + assert resp.json() == answer + + def revoke_token_request(url, headers, data): + assert url == "https://provider.test/token" + return url, headers, data + + sess.register_compliance_hook( + "revoke_token_request", + revoke_token_request, + ) + sess.revoke_token( + "https://provider.test/token", "hi", body="", token_type_hint="access_token" + ) + + +def test_introspect_token(): + sess = OAuth2Session("a") + answer = { + "active": True, + "client_id": "l238j323ds-23ij4", + "username": "jdoe", + "scope": "read write dolphin", + "sub": "Z5O3upPC88QrAjx00dis", + "aud": "https://resource.test/resource", + "iss": "https://provider.test/", + "exp": 1419356238, + "iat": 1419350238, + } + sess.send = mock_json_response(answer) + resp = sess.introspect_token("https://provider.test/token", "hi") + assert resp.json() == answer + + +def test_client_secret_jwt(token): + sess = OAuth2Session("id", "secret", token_endpoint_auth_method="client_secret_jwt") + sess.register_client_auth_method(ClientSecretJWT()) + + mock_assertion_response(token, sess) + token = sess.fetch_token("https://provider.test/token") + assert token == token + + +def test_client_secret_jwt2(token): + sess = OAuth2Session( + "id", + OctKey.import_key("secret"), + token_endpoint_auth_method=ClientSecretJWT(), + ) + mock_assertion_response(token, sess) + token = sess.fetch_token("https://provider.test/token") + assert token == token + + +def test_private_key_jwt(token): + client_secret = read_key_file("rsa_private.pem") + sess = OAuth2Session( + "id", client_secret, token_endpoint_auth_method="private_key_jwt" + ) + sess.register_client_auth_method(PrivateKeyJWT()) + mock_assertion_response(token, sess) + token = sess.fetch_token("https://provider.test/token") + assert token == token + + +def test_private_key_jwt2(token): + client_secret = RSAKey.import_key(read_key_file("rsa_private.pem")) + sess = OAuth2Session( + "id", + client_secret, + token_endpoint_auth_method=PrivateKeyJWT(), + ) + mock_assertion_response(token, sess) + token = sess.fetch_token("https://provider.test/token") + assert token == token + + +def test_custom_client_auth_method(token): + def auth_client(client, method, uri, headers, body): + uri = add_params_to_uri( + uri, + [ + ("client_id", client.client_id), + ("client_secret", client.client_secret), + ], + ) + uri = uri + "&" + body + body = "" + return uri, headers, body + + sess = OAuth2Session("id", "secret", token_endpoint_auth_method="client_secret_uri") + sess.register_client_auth_method(("client_secret_uri", auth_client)) + + def fake_send(r, **kwargs): + assert "client_id=" in r.url + assert "client_secret=" in r.url + resp = mock.MagicMock() + resp.status_code = 200 + resp.json = lambda: token + return resp + + sess.send = fake_send + token = sess.fetch_token("https://provider.test/token") + assert token == token + + +def test_use_client_token_auth(token): + import requests + + expected_header = "Bearer " + token["access_token"] + + def verifier(r, **kwargs): + auth_header = r.headers.get("Authorization", None) + assert auth_header == expected_header + resp = mock.MagicMock() + return resp + + client = OAuth2Session(client_id="foo", token=token) + + sess = requests.Session() + sess.send = verifier + sess.get("https://provider.test", auth=client.token_auth) + + +def test_use_default_request_timeout(token): + expected_timeout = 15 + + def verifier(r, **kwargs): + timeout = kwargs.get("timeout") + assert timeout == expected_timeout + resp = mock.MagicMock() + return resp + + client = OAuth2Session( + client_id="foo", + token=token, + default_timeout=expected_timeout, + ) + + client.send = verifier + client.request("GET", "https://provider.test", withhold_token=False) + + +def test_override_default_request_timeout(token): + default_timeout = 15 + expected_timeout = 10 + + def verifier(r, **kwargs): + timeout = kwargs.get("timeout") + assert timeout == expected_timeout + resp = mock.MagicMock() + return resp + + client = OAuth2Session( + client_id="foo", + token=token, + default_timeout=default_timeout, + ) + + client.send = verifier + client.request( + "GET", "https://provider.test", withhold_token=False, timeout=expected_timeout + ) diff --git a/tests/py3/test_httpx_client/__init__.py b/tests/clients/test_starlette/__init__.py similarity index 100% rename from tests/py3/test_httpx_client/__init__.py rename to tests/clients/test_starlette/__init__.py diff --git a/tests/clients/test_starlette/test_oauth_client.py b/tests/clients/test_starlette/test_oauth_client.py new file mode 100644 index 000000000..b4a6b655a --- /dev/null +++ b/tests/clients/test_starlette/test_oauth_client.py @@ -0,0 +1,850 @@ +import json + +import pytest +from httpx import ASGITransport +from starlette.config import Config +from starlette.datastructures import URL +from starlette.requests import Request + +from authlib.common.urls import url_decode +from authlib.common.urls import urlparse +from authlib.integrations.starlette_client import OAuth +from authlib.integrations.starlette_client import OAuthError + +from ..asgi_helper import AsyncPathMapDispatch +from ..util import get_bearer_token + + +class AsyncDummyCache: + """Simple async cache for testing.""" + + def __init__(self): + self._data = {} + + async def get(self, key): + return self._data.get(key) + + async def set(self, key, value, expires_in=None): + self._data[key] = value + + async def delete(self, key): + self._data.pop(key, None) + + +def test_register_remote_app(): + oauth = OAuth() + with pytest.raises(AttributeError): + assert oauth.dev.name == "dev" + + oauth.register( + "dev", + client_id="dev", + client_secret="dev", + ) + assert oauth.dev.name == "dev" + assert oauth.dev.client_id == "dev" + + +def test_register_with_config(): + config = Config(environ={"DEV_CLIENT_ID": "dev"}) + oauth = OAuth(config) + oauth.register("dev") + assert oauth.dev.name == "dev" + assert oauth.dev.client_id == "dev" + + +def test_register_with_overwrite(): + config = Config(environ={"DEV_CLIENT_ID": "dev"}) + oauth = OAuth(config) + oauth.register("dev", client_id="not-dev", overwrite=True) + assert oauth.dev.name == "dev" + assert oauth.dev.client_id == "dev" + + +@pytest.mark.asyncio +async def test_oauth1_authorize(): + oauth = OAuth() + transport = ASGITransport( + AsyncPathMapDispatch( + { + "/request-token": {"body": "oauth_token=foo&oauth_verifier=baz"}, + "/token": {"body": "oauth_token=a&oauth_token_secret=b"}, + } + ) + ) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + request_token_url="https://provider.test/request-token", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + client_kwargs={ + "transport": transport, + }, + ) + + req_scope = {"type": "http", "session": {}} + req = Request(req_scope) + resp = await client.authorize_redirect(req, "https://client.test/callback") + assert resp.status_code == 302 + url = resp.headers.get("Location") + assert "oauth_token=foo" in url + assert "_state_dev_foo" in req.session + req.scope["query_string"] = "oauth_token=foo&oauth_verifier=baz" + token = await client.authorize_access_token(req) + assert token["oauth_token"] == "a" + + +@pytest.mark.asyncio +async def test_oauth2_authorize(): + oauth = OAuth() + transport = ASGITransport( + AsyncPathMapDispatch({"/token": {"body": get_bearer_token()}}) + ) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + client_kwargs={ + "transport": transport, + }, + ) + + req_scope = {"type": "http", "session": {}} + req = Request(req_scope) + resp = await client.authorize_redirect(req, "https://client.test/callback") + assert resp.status_code == 302 + url = resp.headers.get("Location") + assert "state=" in url + state = dict(url_decode(urlparse.urlparse(url).query))["state"] + + assert f"_state_dev_{state}" in req.session + + req_scope.update( + { + "path": "/", + "query_string": f"code=a&state={state}", + "session": req.session, + } + ) + req = Request(req_scope) + token = await client.authorize_access_token(req) + assert token["access_token"] == "a" + + +class _FakeAsyncCache: + """Minimal async cache implementing the authlib framework cache protocol.""" + + def __init__(self): + self.store = {} + + async def get(self, key): + return self.store.get(key) + + async def set(self, key, value, expires=None): + self.store[key] = value + + async def delete(self, key): + self.store.pop(key, None) + + +@pytest.mark.asyncio +async def test_oauth2_authorize_csrf_with_cache(): + """When a cache is configured, the state must still be bound to the + session that initiated the flow. Otherwise an attacker can start an + authorization request, stop before the callback, and trick a victim into + completing the flow — logging the victim into the attacker's account + (RFC 6749 §10.12).""" + transport = ASGITransport( + AsyncPathMapDispatch({"/token": {"body": get_bearer_token()}}) + ) + oauth = OAuth(cache=_FakeAsyncCache()) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + client_kwargs={ + "transport": transport, + }, + ) + + # Attacker initiates an auth flow from their own session. + attacker_req = Request({"type": "http", "session": {}}) + resp = await client.authorize_redirect(attacker_req, "https://client.test/callback") + assert resp.status_code == 302 + url = resp.headers.get("Location") + state = dict(url_decode(urlparse.urlparse(url).query))["state"] + + # Victim is tricked into hitting the callback URL. The victim's browser + # carries a *different* session — they never initiated this flow. + victim_req = Request( + { + "type": "http", + "path": "/", + "query_string": f"code=a&state={state}".encode(), + "session": {}, + } + ) + with pytest.raises(OAuthError): + await client.authorize_access_token(victim_req) + + +@pytest.mark.asyncio +async def test_oauth2_authorize_access_denied(): + oauth = OAuth() + transport = ASGITransport( + AsyncPathMapDispatch({"/token": {"body": get_bearer_token()}}) + ) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + client_kwargs={ + "transport": transport, + }, + ) + + req = Request( + { + "type": "http", + "session": {}, + "path": "/", + "query_string": "error=access_denied&error_description=Not+Allowed", + } + ) + with pytest.raises(OAuthError): + await client.authorize_access_token(req) + + +@pytest.mark.asyncio +async def test_oauth2_authorize_code_challenge(): + transport = ASGITransport( + AsyncPathMapDispatch({"/token": {"body": get_bearer_token()}}) + ) + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + client_kwargs={ + "code_challenge_method": "S256", + "transport": transport, + }, + ) + + req_scope = {"type": "http", "session": {}} + req = Request(req_scope) + + resp = await client.authorize_redirect( + req, redirect_uri="https://client.test/callback" + ) + assert resp.status_code == 302 + + url = resp.headers.get("Location") + assert "code_challenge=" in url + assert "code_challenge_method=S256" in url + + state = dict(url_decode(urlparse.urlparse(url).query))["state"] + state_data = req.session[f"_state_dev_{state}"]["data"] + + verifier = state_data["code_verifier"] + assert verifier is not None + + req_scope.update( + { + "path": "/", + "query_string": f"code=a&state={state}".encode(), + "session": req.session, + } + ) + req = Request(req_scope) + + token = await client.authorize_access_token(req) + assert token["access_token"] == "a" + + +@pytest.mark.asyncio +async def test_with_fetch_token_in_register(): + async def fetch_token(request): + return {"access_token": "dev", "token_type": "bearer"} + + transport = ASGITransport(AsyncPathMapDispatch({"/user": {"body": {"sub": "123"}}})) + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + fetch_token=fetch_token, + client_kwargs={ + "transport": transport, + }, + ) + + req_scope = {"type": "http", "session": {}} + req = Request(req_scope) + resp = await client.get("/user", request=req) + assert resp.json()["sub"] == "123" + + +@pytest.mark.asyncio +async def test_with_fetch_token_in_oauth(): + async def fetch_token(name, request): + return {"access_token": "dev", "token_type": "bearer"} + + transport = ASGITransport(AsyncPathMapDispatch({"/user": {"body": {"sub": "123"}}})) + oauth = OAuth(fetch_token=fetch_token) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + client_kwargs={ + "transport": transport, + }, + ) + + req_scope = {"type": "http", "session": {}} + req = Request(req_scope) + resp = await client.get("/user", request=req) + assert resp.json()["sub"] == "123" + + +@pytest.mark.asyncio +async def test_request_withhold_token(): + oauth = OAuth() + transport = ASGITransport(AsyncPathMapDispatch({"/user": {"body": {"sub": "123"}}})) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + client_kwargs={ + "transport": transport, + }, + ) + req_scope = {"type": "http", "session": {}} + req = Request(req_scope) + resp = await client.get("/user", request=req, withhold_token=True) + assert resp.json()["sub"] == "123" + + +@pytest.mark.asyncio +async def test_oauth2_authorize_no_url(): + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + ) + req_scope = {"type": "http", "session": {}} + req = Request(req_scope) + with pytest.raises(RuntimeError): + await client.create_authorization_url(req) + + +@pytest.mark.asyncio +async def test_oauth2_fetch_metadata(): + def assert_headers(req): + assert "Authlib/" in req.headers.get("user-agent", "") + + oauth = OAuth() + transport = ASGITransport( + AsyncPathMapDispatch( + path_maps={ + "/.well-known/openid-configuration": { + "body": { + "authorization_endpoint": "https://provider.test/authorize", + "jwks_uri": "https://provider.test/.well-known/keys", + } + }, + "/.well-known/keys": {"body": {"keys": []}}, + }, + side_effects={ + "/.well-known/openid-configuration": assert_headers, + "/.well-known/keys": assert_headers, + }, + ) + ) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + client_kwargs={ + "transport": transport, + }, + ) + await client.fetch_jwk_set() + + +@pytest.mark.asyncio +async def test_oauth2_authorize_with_metadata(): + oauth = OAuth() + transport = ASGITransport( + AsyncPathMapDispatch( + { + "/.well-known/openid-configuration": { + "body": { + "authorization_endpoint": "https://provider.test/authorize" + } + } + } + ) + ) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + client_kwargs={ + "transport": transport, + }, + ) + req_scope = {"type": "http", "session": {}} + req = Request(req_scope) + resp = await client.authorize_redirect(req, "https://client.test/callback") + assert resp.status_code == 302 + + +@pytest.mark.asyncio +async def test_oauth2_authorize_form_post_callback(): + """Test that POST callbacks (form_post response mode) work properly.""" + oauth = OAuth() + transport = ASGITransport( + AsyncPathMapDispatch({"/token": {"body": get_bearer_token()}}) + ) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + client_kwargs={ + "transport": transport, + }, + ) + + req = Request({"type": "http", "session": {}}) + resp = await client.authorize_redirect(req, "https://client.test/callback") + url = resp.headers.get("Location") + state = dict(url_decode(urlparse.urlparse(url).query))["state"] + + req_scope_post = { + "type": "http", + "method": "POST", + "path": "/callback", + "query_string": b"", + "headers": [(b"content-type", b"application/x-www-form-urlencoded")], + "session": req.session, + } + + async def receive(): + return { + "type": "http.request", + "body": f"code=test_code&state={state}".encode(), + } + + req_post = Request(req_scope_post, receive=receive) + + token = await client.authorize_access_token(req_post) + assert token["access_token"] == "a" + + +@pytest.mark.asyncio +async def test_logout_redirect(): + """Test logout_redirect generates correct URL with state stored in session.""" + oauth = OAuth() + transport = ASGITransport( + AsyncPathMapDispatch( + { + "/.well-known/openid-configuration": { + "body": { + "issuer": "https://provider.test", + "end_session_endpoint": "https://provider.test/logout", + } + } + } + ) + ) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + client_kwargs={ + "transport": transport, + }, + ) + + req_scope = {"type": "http", "session": {}} + req = Request(req_scope) + resp = await client.logout_redirect( + req, + post_logout_redirect_uri="https://client.test/logged-out", + id_token_hint="fake.id.token", + ) + assert resp.status_code == 302 + url = resp.headers.get("Location") + assert "https://provider.test/logout" in url + assert "id_token_hint=fake.id.token" in url + assert "post_logout_redirect_uri" in url + assert "state=" in url + + # Verify state is stored in session + params = dict(url_decode(urlparse.urlparse(url).query)) + state = params["state"] + assert f"_state_dev_{state}" in req.session + + +@pytest.mark.asyncio +async def test_logout_redirect_without_redirect_uri(): + """Test logout_redirect omits state when no post_logout_redirect_uri is provided.""" + oauth = OAuth() + transport = ASGITransport( + AsyncPathMapDispatch( + { + "/.well-known/openid-configuration": { + "body": { + "issuer": "https://provider.test", + "end_session_endpoint": "https://provider.test/logout", + } + } + } + ) + ) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + client_kwargs={ + "transport": transport, + }, + ) + + req_scope = {"type": "http", "session": {}} + req = Request(req_scope) + resp = await client.logout_redirect(req, id_token_hint="fake.id.token") + assert resp.status_code == 302 + url = resp.headers.get("Location") + assert "id_token_hint=fake.id.token" in url + assert "state" not in url + + +@pytest.mark.asyncio +async def test_logout_redirect_missing_endpoint(): + """Test logout_redirect raises RuntimeError when end_session_endpoint is missing.""" + oauth = OAuth() + transport = ASGITransport( + AsyncPathMapDispatch( + { + "/.well-known/openid-configuration": { + "body": { + "issuer": "https://provider.test", + "authorization_endpoint": "https://provider.test/authorize", + } + } + } + ) + ) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + client_kwargs={ + "transport": transport, + }, + ) + + req_scope = {"type": "http", "session": {}} + req = Request(req_scope) + with pytest.raises(RuntimeError, match='Missing "end_session_endpoint"'): + await client.logout_redirect(req) + + +@pytest.mark.asyncio +async def test_validate_logout_response(): + """Test validate_logout_response verifies state and returns stored data.""" + oauth = OAuth() + transport = ASGITransport( + AsyncPathMapDispatch( + { + "/.well-known/openid-configuration": { + "body": { + "issuer": "https://provider.test", + "end_session_endpoint": "https://provider.test/logout", + } + } + } + ) + ) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + client_kwargs={ + "transport": transport, + }, + ) + + req_scope = {"type": "http", "session": {}} + req = Request(req_scope) + resp = await client.logout_redirect( + req, + post_logout_redirect_uri="https://client.test/logged-out", + ) + url = resp.headers.get("Location") + params = dict(url_decode(urlparse.urlparse(url).query)) + state = params["state"] + + req_scope2 = { + "type": "http", + "session": req.session, + "query_string": f"state={state}", + } + req2 = Request(req_scope2) + state_data = await client.validate_logout_response(req2) + assert state_data["post_logout_redirect_uri"] == "https://client.test/logged-out" + # State should be cleared from session + assert f"_state_dev_{state}" not in req2.session + + +@pytest.mark.asyncio +async def test_validate_logout_response_missing_state(): + """Test validate_logout_response raises OAuthError when state is missing.""" + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + + req = Request({"type": "http", "session": {}, "query_string": ""}) + with pytest.raises(OAuthError, match='Missing "state" parameter'): + await client.validate_logout_response(req) + + +@pytest.mark.asyncio +async def test_validate_logout_response_invalid_state(): + """Test validate_logout_response raises OAuthError when state is invalid.""" + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + ) + + req = Request( + {"type": "http", "session": {}, "query_string": "state=invalid-state"} + ) + with pytest.raises(OAuthError, match='Invalid "state" parameter'): + await client.validate_logout_response(req) + + +@pytest.mark.asyncio +async def test_logout_redirect_with_cache(): + """Test logout_redirect stores state in cache instead of session.""" + cache = AsyncDummyCache() + oauth = OAuth(cache=cache) + transport = ASGITransport( + AsyncPathMapDispatch( + { + "/.well-known/openid-configuration": { + "body": { + "issuer": "https://provider.test", + "end_session_endpoint": "https://provider.test/logout", + } + } + } + ) + ) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + client_kwargs={ + "transport": transport, + }, + ) + + req_scope = {"type": "http", "session": {}} + req = Request(req_scope) + resp = await client.logout_redirect( + req, + post_logout_redirect_uri="https://client.test/logged-out", + ) + assert resp.status_code == 302 + url = resp.headers.get("Location") + params = dict(url_decode(urlparse.urlparse(url).query)) + state = params["state"] + + # With cache, data is in cache, not in session + cache_key = f"_state_dev_{state}" + cached_data = await cache.get(cache_key) + assert cached_data is not None + assert ( + json.loads(cached_data)["data"]["post_logout_redirect_uri"] + == "https://client.test/logged-out" + ) + + +@pytest.mark.asyncio +async def test_logout_redirect_with_url_object(): + """Test logout_redirect handles URL objects for post_logout_redirect_uri.""" + oauth = OAuth() + transport = ASGITransport( + AsyncPathMapDispatch( + { + "/.well-known/openid-configuration": { + "body": { + "issuer": "https://provider.test", + "end_session_endpoint": "https://provider.test/logout", + } + } + } + ) + ) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + client_kwargs={ + "transport": transport, + }, + ) + + req_scope = {"type": "http", "session": {}} + req = Request(req_scope) + # Pass URL object instead of string + redirect_uri = URL("https://client.test/logged-out") + resp = await client.logout_redirect( + req, + post_logout_redirect_uri=redirect_uri, + ) + assert resp.status_code == 302 + url = resp.headers.get("Location") + assert "post_logout_redirect_uri=https%3A%2F%2Fclient.test%2Flogged-out" in url + + +@pytest.mark.asyncio +async def test_validate_logout_response_with_cache(): + """Test validate_logout_response retrieves state from cache.""" + cache = AsyncDummyCache() + oauth = OAuth(cache=cache) + transport = ASGITransport( + AsyncPathMapDispatch( + { + "/.well-known/openid-configuration": { + "body": { + "issuer": "https://provider.test", + "end_session_endpoint": "https://provider.test/logout", + } + } + } + ) + ) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + client_kwargs={ + "transport": transport, + }, + ) + + req_scope = {"type": "http", "session": {}} + req = Request(req_scope) + resp = await client.logout_redirect( + req, + post_logout_redirect_uri="https://client.test/logged-out", + ) + url = resp.headers.get("Location") + params = dict(url_decode(urlparse.urlparse(url).query)) + state = params["state"] + + # Validate the response — use the same session to prove continuity + req2 = Request( + {"type": "http", "session": req.session, "query_string": f"state={state}"} + ) + state_data = await client.validate_logout_response(req2) + assert state_data["post_logout_redirect_uri"] == "https://client.test/logged-out" + + # Cache should be cleared + cache_key = f"_state_dev_{state}" + assert await cache.get(cache_key) is None + + +@pytest.mark.asyncio +async def test_logout_redirect_with_extra_params(): + """Test logout_redirect includes optional params: client_id, logout_hint, ui_locales.""" + oauth = OAuth() + transport = ASGITransport( + AsyncPathMapDispatch( + { + "/.well-known/openid-configuration": { + "body": { + "issuer": "https://provider.test", + "end_session_endpoint": "https://provider.test/logout", + } + } + } + ) + ) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + server_metadata_url="https://provider.test/.well-known/openid-configuration", + client_kwargs={ + "transport": transport, + }, + ) + + req_scope = {"type": "http", "session": {}} + req = Request(req_scope) + resp = await client.logout_redirect( + req, + post_logout_redirect_uri="https://client.test/logged-out", + client_id="dev", + logout_hint="user@example.com", + ui_locales="fr", + ) + assert resp.status_code == 302 + url = resp.headers.get("Location") + assert "client_id=dev" in url + assert "logout_hint=user%40example.com" in url + assert "ui_locales=fr" in url diff --git a/tests/clients/test_starlette/test_user_mixin.py b/tests/clients/test_starlette/test_user_mixin.py new file mode 100644 index 000000000..80f4df0c6 --- /dev/null +++ b/tests/clients/test_starlette/test_user_mixin.py @@ -0,0 +1,158 @@ +import time + +import pytest +from httpx import ASGITransport +from joserfc import jwk +from joserfc import jwt +from joserfc.errors import InvalidClaimError +from joserfc.jwk import KeySet +from starlette.requests import Request + +from authlib.integrations.starlette_client import OAuth +from authlib.oidc.core.grants.util import create_half_hash + +from ..asgi_helper import AsyncPathMapDispatch +from ..util import get_bearer_token +from ..util import read_key_file + +secret_key = jwk.import_key("test-oct-secret", "oct", {"kid": "f"}) + + +async def run_fetch_userinfo(payload): + oauth = OAuth() + + async def fetch_token(request): + return get_bearer_token() + + transport = ASGITransport(AsyncPathMapDispatch({"/userinfo": {"body": payload}})) + + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + fetch_token=fetch_token, + userinfo_endpoint="https://provider.test/userinfo", + client_kwargs={ + "transport": transport, + }, + ) + + req_scope = {"type": "http", "session": {}} + req = Request(req_scope) + user = await client.userinfo(request=req) + assert user.sub == "123" + + +@pytest.mark.asyncio +async def test_fetch_userinfo(): + await run_fetch_userinfo({"sub": "123"}) + + +@pytest.mark.asyncio +async def test_parse_id_token(): + token = get_bearer_token() + now = int(time.time()) + claims = { + "sub": "123", + "iss": "https://provider.test", + "aud": "dev", + "iat": now, + "auth_time": now, + "exp": now + 3600, + "nonce": "n", + "at_hash": create_half_hash(token["access_token"], "HS256").decode("utf-8"), + } + id_token = jwt.encode({"alg": "HS256"}, claims, secret_key) + token["id_token"] = id_token + + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + fetch_token=get_bearer_token, + jwks={"keys": [secret_key.as_dict()]}, + issuer="https://provider.test", + id_token_signing_alg_values_supported=["HS256", "RS256"], + ) + user = await client.parse_id_token(token, nonce="n") + assert user.sub == "123" + + claims_options = {"iss": {"value": "https://provider.test"}} + user = await client.parse_id_token(token, nonce="n", claims_options=claims_options) + assert user.sub == "123" + + with pytest.raises(InvalidClaimError): + claims_options = {"iss": {"value": "https://wrong-provider.test"}} + await client.parse_id_token(token, nonce="n", claims_options=claims_options) + + +@pytest.mark.asyncio +async def test_runtime_error_fetch_jwks_uri(): + token = get_bearer_token() + now = int(time.time()) + claims = { + "sub": "123", + "iss": "https://provider.test", + "aud": "dev", + "iat": now, + "auth_time": now, + "exp": now + 3600, + "nonce": "n", + "at_hash": create_half_hash(token["access_token"], "HS256").decode("utf-8"), + } + id_token = jwt.encode({"alg": "HS256"}, claims, secret_key) + + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + fetch_token=get_bearer_token, + issuer="https://provider.test", + id_token_signing_alg_values_supported=["HS256"], + ) + req_scope = {"type": "http", "session": {"_dev_authlib_nonce_": "n"}} + req = Request(req_scope) + token["id_token"] = id_token + with pytest.raises(RuntimeError): + await client.parse_id_token(req, token) + + +@pytest.mark.asyncio +async def test_force_fetch_jwks_uri(): + secret_keys = KeySet.import_key_set(read_key_file("jwks_private.json")) + token = get_bearer_token() + now = int(time.time()) + claims = { + "sub": "123", + "iss": "https://provider.test", + "aud": "dev", + "iat": now, + "auth_time": now, + "exp": now + 3600, + "nonce": "n", + "at_hash": create_half_hash(token["access_token"], "RS256").decode("utf-8"), + } + id_token = jwt.encode({"alg": "RS256"}, claims, secret_keys) + token["id_token"] = id_token + + transport = ASGITransport( + AsyncPathMapDispatch({"/jwks": {"body": read_key_file("jwks_public.json")}}) + ) + + oauth = OAuth() + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + fetch_token=get_bearer_token, + jwks={"keys": [secret_key.as_dict()]}, + jwks_uri="https://provider.test/jwks", + issuer="https://provider.test", + client_kwargs={ + "transport": transport, + }, + ) + user = await client.parse_id_token(token, nonce="n") + assert user.sub == "123" diff --git a/tests/client_base.py b/tests/clients/util.py similarity index 56% rename from tests/client_base.py rename to tests/clients/util.py index 4d67ad28b..d53348350 100644 --- a/tests/client_base.py +++ b/tests/clients/util.py @@ -1,15 +1,19 @@ -from __future__ import unicode_literals, print_function +import json +import os import time +from unittest import mock + import requests -import mock +ROOT = os.path.abspath(os.path.dirname(__file__)) -def mock_json_response(payload): - def fake_send(r, **kwargs): - resp = mock.MagicMock() - resp.json = lambda: payload - return resp - return fake_send + +def read_key_file(name): + file_path = os.path.join(ROOT, "keys", name) + with open(file_path) as f: + if name.endswith(".json"): + return json.load(f) + return f.read() def mock_text_response(body, status_code=200): @@ -19,6 +23,7 @@ def fake_send(r, **kwargs): resp.text = body resp.status_code = status_code return resp + return fake_send @@ -35,9 +40,9 @@ def mock_send_value(body, status_code=200): def get_bearer_token(): return { - 'token_type': 'Bearer', - 'access_token': 'a', - 'refresh_token': 'b', - 'expires_in': '3600', - 'expires_at': int(time.time()) + 3600, + "token_type": "Bearer", + "access_token": "a", + "refresh_token": "b", + "expires_in": "3600", + "expires_at": int(time.time()) + 3600, } diff --git a/tests/clients/wsgi_helper.py b/tests/clients/wsgi_helper.py new file mode 100644 index 000000000..80b5a560e --- /dev/null +++ b/tests/clients/wsgi_helper.py @@ -0,0 +1,35 @@ +import json + +from werkzeug.wrappers import Request as WSGIRequest +from werkzeug.wrappers import Response as WSGIResponse + + +class MockDispatch: + def __init__(self, body=b"", status_code=200, headers=None, assert_func=None): + if headers is None: + headers = {} + if isinstance(body, dict): + body = json.dumps(body).encode() + headers["Content-Type"] = "application/json" + else: + if isinstance(body, str): + body = body.encode() + headers["Content-Type"] = "application/x-www-form-urlencoded" + + self.body = body + self.status_code = status_code + self.headers = headers + self.assert_func = assert_func + + def __call__(self, environ, start_response): + request = WSGIRequest(environ) + + if self.assert_func: + self.assert_func(request) + + response = WSGIResponse( + status=self.status_code, + response=self.body, + headers=self.headers, + ) + return response(environ, start_response) diff --git a/tests/py3/test_starlette_client/__init__.py b/tests/core/__init__.py similarity index 100% rename from tests/py3/test_starlette_client/__init__.py rename to tests/core/__init__.py diff --git a/tests/core/test_jose/test_jwe.py b/tests/core/test_jose/test_jwe.py deleted file mode 100644 index 332500973..000000000 --- a/tests/core/test_jose/test_jwe.py +++ /dev/null @@ -1,273 +0,0 @@ -import os -import unittest -from authlib.jose import errors -from authlib.jose import OctKey, OKPKey -from authlib.jose import JsonWebEncryption -from authlib.common.encoding import urlsafe_b64encode -from tests.util import read_file_path - - -class JWETest(unittest.TestCase): - def test_not_enough_segments(self): - s = 'a.b.c' - jwe = JsonWebEncryption() - self.assertRaises( - errors.DecodeError, - jwe.deserialize_compact, - s, None - ) - - def test_invalid_header(self): - jwe = JsonWebEncryption() - public_key = read_file_path('rsa_public.pem') - self.assertRaises( - errors.MissingAlgorithmError, - jwe.serialize_compact, {}, 'a', public_key - ) - self.assertRaises( - errors.UnsupportedAlgorithmError, - jwe.serialize_compact, {'alg': 'invalid'}, 'a', public_key - ) - self.assertRaises( - errors.MissingEncryptionAlgorithmError, - jwe.serialize_compact, {'alg': 'RSA-OAEP'}, 'a', public_key - ) - self.assertRaises( - errors.UnsupportedEncryptionAlgorithmError, - jwe.serialize_compact, {'alg': 'RSA-OAEP', 'enc': 'invalid'}, - 'a', public_key - ) - self.assertRaises( - errors.UnsupportedCompressionAlgorithmError, - jwe.serialize_compact, - {'alg': 'RSA-OAEP', 'enc': 'A256GCM', 'zip': 'invalid'}, - 'a', public_key - ) - - def test_not_supported_alg(self): - public_key = read_file_path('rsa_public.pem') - private_key = read_file_path('rsa_private.pem') - - jwe = JsonWebEncryption() - s = jwe.serialize_compact( - {'alg': 'RSA-OAEP', 'enc': 'A256GCM'}, - 'hello', public_key - ) - - jwe = JsonWebEncryption(algorithms=['RSA1_5', 'A256GCM']) - self.assertRaises( - errors.UnsupportedAlgorithmError, - jwe.serialize_compact, - {'alg': 'RSA-OAEP', 'enc': 'A256GCM'}, - 'hello', public_key - ) - self.assertRaises( - errors.UnsupportedCompressionAlgorithmError, - jwe.serialize_compact, - {'alg': 'RSA1_5', 'enc': 'A256GCM', 'zip': 'DEF'}, - 'hello', public_key - ) - self.assertRaises( - errors.UnsupportedAlgorithmError, - jwe.deserialize_compact, - s, private_key, - ) - - jwe = JsonWebEncryption(algorithms=['RSA-OAEP', 'A192GCM']) - self.assertRaises( - errors.UnsupportedEncryptionAlgorithmError, - jwe.serialize_compact, - {'alg': 'RSA-OAEP', 'enc': 'A256GCM'}, - 'hello', public_key - ) - self.assertRaises( - errors.UnsupportedCompressionAlgorithmError, - jwe.serialize_compact, - {'alg': 'RSA-OAEP', 'enc': 'A192GCM', 'zip': 'DEF'}, - 'hello', public_key - ) - self.assertRaises( - errors.UnsupportedEncryptionAlgorithmError, - jwe.deserialize_compact, - s, private_key, - ) - - def test_compact_rsa(self): - jwe = JsonWebEncryption() - s = jwe.serialize_compact( - {'alg': 'RSA-OAEP', 'enc': 'A256GCM'}, - 'hello', - read_file_path('rsa_public.pem') - ) - data = jwe.deserialize_compact(s, read_file_path('rsa_private.pem')) - header, payload = data['header'], data['payload'] - self.assertEqual(payload, b'hello') - self.assertEqual(header['alg'], 'RSA-OAEP') - - def test_with_zip_header(self): - jwe = JsonWebEncryption() - s = jwe.serialize_compact( - {'alg': 'RSA-OAEP', 'enc': 'A128CBC-HS256', 'zip': 'DEF'}, - 'hello', - read_file_path('rsa_public.pem') - ) - data = jwe.deserialize_compact(s, read_file_path('rsa_private.pem')) - header, payload = data['header'], data['payload'] - self.assertEqual(payload, b'hello') - self.assertEqual(header['alg'], 'RSA-OAEP') - - def test_aes_jwe(self): - jwe = JsonWebEncryption() - sizes = [128, 192, 256] - _enc_choices = [ - 'A128CBC-HS256', 'A192CBC-HS384', 'A256CBC-HS512', - 'A128GCM', 'A192GCM', 'A256GCM' - ] - for s in sizes: - alg = 'A{}KW'.format(s) - key = os.urandom(s // 8) - for enc in _enc_choices: - protected = {'alg': alg, 'enc': enc} - data = jwe.serialize_compact(protected, b'hello', key) - rv = jwe.deserialize_compact(data, key) - self.assertEqual(rv['payload'], b'hello') - - def test_ase_jwe_invalid_key(self): - jwe = JsonWebEncryption() - protected = {'alg': 'A128KW', 'enc': 'A128GCM'} - self.assertRaises( - ValueError, - jwe.serialize_compact, - protected, b'hello', b'invalid-key' - ) - - def test_aes_gcm_jwe(self): - jwe = JsonWebEncryption() - sizes = [128, 192, 256] - _enc_choices = [ - 'A128CBC-HS256', 'A192CBC-HS384', 'A256CBC-HS512', - 'A128GCM', 'A192GCM', 'A256GCM' - ] - for s in sizes: - alg = 'A{}GCMKW'.format(s) - key = os.urandom(s // 8) - for enc in _enc_choices: - protected = {'alg': alg, 'enc': enc} - data = jwe.serialize_compact(protected, b'hello', key) - rv = jwe.deserialize_compact(data, key) - self.assertEqual(rv['payload'], b'hello') - - def test_ase_gcm_jwe_invalid_key(self): - jwe = JsonWebEncryption() - protected = {'alg': 'A128GCMKW', 'enc': 'A128GCM'} - self.assertRaises( - ValueError, - jwe.serialize_compact, - protected, b'hello', b'invalid-key' - ) - - def test_ecdh_key_agreement_computation(self): - # https://tools.ietf.org/html/rfc7518#appendix-C - alice_key = { - "kty": "EC", - "crv": "P-256", - "x": "gI0GAILBdu7T53akrFmMyGcsF3n5dO7MmwNBHKW5SV0", - "y": "SLW_xSffzlPWrHEVI30DHM_4egVwt3NQqeUD7nMFpps", - "d": "0_NxaRPUMQoAJt50Gz8YiTr8gRTwyEaCumd-MToTmIo" - } - bob_key = { - "kty": "EC", - "crv": "P-256", - "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ", - "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck", - "d": "VEmDZpDXXK8p8N0Cndsxs924q6nS1RXFASRl6BfUqdw" - } - headers = { - "alg": "ECDH-ES", - "enc": "A128GCM", - "apu": "QWxpY2U", - "apv": "Qm9i", - } - alg = JsonWebEncryption.ALG_REGISTRY['ECDH-ES'] - key = alg.prepare_key(alice_key) - bob_key = alg.prepare_key(bob_key) - public_key = bob_key.get_op_key('wrapKey') - dk = alg.deliver(key, public_key, headers, 128) - self.assertEqual(urlsafe_b64encode(dk), b'VqqN6vgjbSBcIijNcacQGg') - - def test_ecdh_es_jwe(self): - jwe = JsonWebEncryption() - key = { - "kty": "EC", - "crv": "P-256", - "x": "gI0GAILBdu7T53akrFmMyGcsF3n5dO7MmwNBHKW5SV0", - "y": "SLW_xSffzlPWrHEVI30DHM_4egVwt3NQqeUD7nMFpps", - "d": "0_NxaRPUMQoAJt50Gz8YiTr8gRTwyEaCumd-MToTmIo" - } - for alg in ["ECDH-ES", "ECDH-ES+A128KW", "ECDH-ES+A192KW", "ECDH-ES+A256KW"]: - protected = {'alg': alg, 'enc': 'A128GCM'} - data = jwe.serialize_compact(protected, b'hello', key) - rv = jwe.deserialize_compact(data, key) - self.assertEqual(rv['payload'], b'hello') - - def test_ecdh_es_with_okp(self): - jwe = JsonWebEncryption() - key = OKPKey.generate_key('X25519', is_private=True) - for alg in ["ECDH-ES", "ECDH-ES+A128KW", "ECDH-ES+A192KW", "ECDH-ES+A256KW"]: - protected = {'alg': alg, 'enc': 'A128GCM'} - data = jwe.serialize_compact(protected, b'hello', key) - rv = jwe.deserialize_compact(data, key) - self.assertEqual(rv['payload'], b'hello') - - def test_ecdh_es_raise(self): - jwe = JsonWebEncryption() - protected = {'alg': 'ECDH-ES', 'enc': 'A128GCM'} - key = { - "kty": "EC", - "crv": "P-256", - "x": "gI0GAILBdu7T53akrFmMyGcsF3n5dO7MmwNBHKW5SV0", - "y": "SLW_xSffzlPWrHEVI30DHM_4egVwt3NQqeUD7nMFpps", - } - data = jwe.serialize_compact(protected, b'hello', key) - self.assertRaises(ValueError, jwe.deserialize_compact, data, key) - - key = OKPKey.generate_key('Ed25519', is_private=True) - self.assertRaises( - ValueError, - jwe.serialize_compact, - protected, b'hello', key - ) - - def test_dir_alg(self): - jwe = JsonWebEncryption() - key = OctKey.generate_key(128, is_private=True) - protected = {'alg': 'dir', 'enc': 'A128GCM'} - data = jwe.serialize_compact(protected, b'hello', key) - rv = jwe.deserialize_compact(data, key) - self.assertEqual(rv['payload'], b'hello') - - key2 = OctKey.generate_key(256, is_private=True) - self.assertRaises(ValueError, jwe.deserialize_compact, data, key2) - - self.assertRaises( - ValueError, - jwe.serialize_compact, - protected, b'hello', key2 - ) - - def test_dir_alg_c20p(self): - jwe = JsonWebEncryption() - key = OctKey.generate_key(256, is_private=True) - protected = {'alg': 'dir', 'enc': 'C20P'} - data = jwe.serialize_compact(protected, b'hello', key) - rv = jwe.deserialize_compact(data, key) - self.assertEqual(rv['payload'], b'hello') - - key2 = OctKey.generate_key(128, is_private=True) - self.assertRaises(ValueError, jwe.deserialize_compact, data, key2) - - self.assertRaises( - ValueError, - jwe.serialize_compact, - protected, b'hello', key2 - ) diff --git a/tests/core/test_jose/test_jwk.py b/tests/core/test_jose/test_jwk.py deleted file mode 100644 index 2b679c1c8..000000000 --- a/tests/core/test_jose/test_jwk.py +++ /dev/null @@ -1,216 +0,0 @@ -import unittest -from authlib.jose import jwk, JsonWebKey, KeySet -from authlib.jose import RSAKey, ECKey, OKPKey -from authlib.common.encoding import base64_to_int -from tests.util import read_file_path - -RSA_PRIVATE_KEY = read_file_path('jwk_private.json') - - -class JWKTest(unittest.TestCase): - def assertBase64IntEqual(self, x, y): - self.assertEqual(base64_to_int(x), base64_to_int(y)) - - def test_ec_public_key(self): - # https://tools.ietf.org/html/rfc7520#section-3.1 - obj = read_file_path('ec_public.json') - key = jwk.loads(obj) - new_obj = jwk.dumps(key) - self.assertEqual(new_obj['crv'], obj['crv']) - self.assertBase64IntEqual(new_obj['x'], obj['x']) - self.assertBase64IntEqual(new_obj['y'], obj['y']) - self.assertEqual(key.as_json()[0], '{') - - def test_ec_private_key(self): - # https://tools.ietf.org/html/rfc7520#section-3.2 - obj = read_file_path('ec_private.json') - key = jwk.loads(obj) - new_obj = jwk.dumps(key, 'EC') - self.assertEqual(new_obj['crv'], obj['crv']) - self.assertBase64IntEqual(new_obj['x'], obj['x']) - self.assertBase64IntEqual(new_obj['y'], obj['y']) - self.assertBase64IntEqual(new_obj['d'], obj['d']) - - def test_invalid_ec(self): - self.assertRaises(ValueError, jwk.loads, {'kty': 'EC'}) - self.assertRaises(ValueError, jwk.dumps, '', 'EC') - - def test_rsa_public_key(self): - # https://tools.ietf.org/html/rfc7520#section-3.3 - obj = read_file_path('jwk_public.json') - key = jwk.loads(obj) - new_obj = jwk.dumps(key) - self.assertBase64IntEqual(new_obj['n'], obj['n']) - self.assertBase64IntEqual(new_obj['e'], obj['e']) - - def test_rsa_private_key(self): - # https://tools.ietf.org/html/rfc7520#section-3.4 - obj = RSA_PRIVATE_KEY - key = jwk.loads(obj) - new_obj = jwk.dumps(key, 'RSA') - self.assertBase64IntEqual(new_obj['n'], obj['n']) - self.assertBase64IntEqual(new_obj['e'], obj['e']) - self.assertBase64IntEqual(new_obj['d'], obj['d']) - self.assertBase64IntEqual(new_obj['p'], obj['p']) - self.assertBase64IntEqual(new_obj['q'], obj['q']) - self.assertBase64IntEqual(new_obj['dp'], obj['dp']) - self.assertBase64IntEqual(new_obj['dq'], obj['dq']) - self.assertBase64IntEqual(new_obj['qi'], obj['qi']) - - def test_rsa_private_key2(self): - obj = { - "kty": "RSA", - "kid": "bilbo.baggins@hobbiton.example", - "use": "sig", - "n": RSA_PRIVATE_KEY['n'], - 'd': RSA_PRIVATE_KEY['d'], - "e": "AQAB" - } - key = jwk.loads(obj) - new_obj = jwk.dumps(key.raw_key, 'RSA') - self.assertBase64IntEqual(new_obj['n'], obj['n']) - self.assertBase64IntEqual(new_obj['e'], obj['e']) - self.assertBase64IntEqual(new_obj['d'], obj['d']) - self.assertBase64IntEqual(new_obj['p'], RSA_PRIVATE_KEY['p']) - self.assertBase64IntEqual(new_obj['q'], RSA_PRIVATE_KEY['q']) - self.assertBase64IntEqual(new_obj['dp'], RSA_PRIVATE_KEY['dp']) - self.assertBase64IntEqual(new_obj['dq'], RSA_PRIVATE_KEY['dq']) - self.assertBase64IntEqual(new_obj['qi'], RSA_PRIVATE_KEY['qi']) - - def test_invalid_rsa(self): - obj = { - "kty": "RSA", - "kid": "bilbo.baggins@hobbiton.example", - "use": "sig", - "n": RSA_PRIVATE_KEY['n'], - 'd': RSA_PRIVATE_KEY['d'], - 'p': RSA_PRIVATE_KEY['p'], - "e": "AQAB" - } - self.assertRaises(ValueError, jwk.loads, obj) - self.assertRaises(ValueError, jwk.loads, {'kty': 'RSA'}) - self.assertRaises(ValueError, jwk.dumps, '', 'RSA') - - def test_dumps_okp_public_key(self): - key = read_file_path('ed25519-ssh.pub') - self.assertRaises(ValueError, jwk.dumps, key) - - obj = jwk.dumps(key, 'OKP') - self.assertEqual(obj['kty'], 'OKP') - self.assertEqual(obj['crv'], 'Ed25519') - - key = read_file_path('ed25519-pub.pem') - obj = jwk.dumps(key, 'OKP') - self.assertEqual(obj['kty'], 'OKP') - self.assertEqual(obj['crv'], 'Ed25519') - - def test_loads_okp_public_key(self): - obj = { - "x": "AD9E0JYnpV-OxZbd8aN1t4z71Vtf6JcJC7TYHT0HDbg", - "crv": "Ed25519", - "kty": "OKP" - } - key = jwk.loads(obj) - new_obj = jwk.dumps(key) - self.assertEqual(obj['x'], new_obj['x']) - - def test_dumps_okp_private_key(self): - key = read_file_path('ed25519-pkcs8.pem') - obj = jwk.dumps(key, 'OKP') - self.assertEqual(obj['kty'], 'OKP') - self.assertEqual(obj['crv'], 'Ed25519') - self.assertIn('d', obj) - - def test_loads_okp_private_key(self): - obj = { - 'x': '11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo', - 'd': 'nWGxne_9WmC6hEr0kuwsxERJxWl7MmkZcDusAxyuf2A', - 'crv': 'Ed25519', - 'kty': 'OKP' - } - key = jwk.loads(obj) - new_obj = jwk.dumps(key) - self.assertEqual(obj['d'], new_obj['d']) - - def test_mac_computation(self): - # https://tools.ietf.org/html/rfc7520#section-3.5 - obj = { - "kty": "oct", - "kid": "018c0ae5-4d9b-471b-bfd6-eef314bc7037", - "use": "sig", - "alg": "HS256", - "k": "hJtXIZ2uSN5kbQfbtTNWbpdmhkV8FJG-Onbc6mxCcYg" - } - key = jwk.loads(obj) - new_obj = jwk.dumps(key) - self.assertEqual(obj['k'], new_obj['k']) - self.assertIn('use', new_obj) - - new_obj = jwk.dumps(key, use='sig') - self.assertEqual(new_obj['use'], 'sig') - - def test_jwk_loads(self): - self.assertRaises(ValueError, jwk.loads, {}) - self.assertRaises(ValueError, jwk.loads, {}, 'k') - - obj = { - "kty": "oct", - "kid": "018c0ae5-4d9b-471b-bfd6-eef314bc7037", - "use": "sig", - "alg": "HS256", - "k": "hJtXIZ2uSN5kbQfbtTNWbpdmhkV8FJG-Onbc6mxCcYg" - } - self.assertRaises(ValueError, jwk.loads, [obj], 'invalid-kid') - - def test_jwk_dumps_ssh(self): - key = read_file_path('ssh_public.pem') - obj = jwk.dumps(key, kty='RSA') - self.assertEqual(obj['kty'], 'RSA') - - def test_thumbprint(self): - # https://tools.ietf.org/html/rfc7638#section-3.1 - data = read_file_path('thumbprint_example.json') - key = JsonWebKey.import_key(data) - expected = 'NzbLsXh8uDCcd-6MNwXF4W_7noWXFZAfHkxZsRGC9Xs' - self.assertEqual(key.thumbprint(), expected) - - def test_key_set(self): - key = RSAKey.generate_key(is_private=True) - key_set = KeySet([key]) - obj = key_set.as_dict()['keys'][0] - self.assertIn('kid', obj) - self.assertEqual(key_set.as_json()[0], '{') - - def test_rsa_key_generate_pem(self): - self.assertRaises(ValueError, RSAKey.generate_key, 256) - self.assertRaises(ValueError, RSAKey.generate_key, 2001) - - key1 = RSAKey.generate_key(is_private=True) - self.assertIn(b'PRIVATE', key1.as_pem(is_private=True)) - self.assertIn(b'PUBLIC', key1.as_pem(is_private=False)) - - key2 = RSAKey.generate_key(is_private=False) - self.assertRaises(ValueError, key2.as_pem, True) - self.assertIn(b'PUBLIC', key2.as_pem(is_private=False)) - - def test_ec_key_generate_pem(self): - self.assertRaises(ValueError, ECKey.generate_key, 'invalid') - - key1 = ECKey.generate_key('P-384', is_private=True) - self.assertIn(b'PRIVATE', key1.as_pem(is_private=True)) - self.assertIn(b'PUBLIC', key1.as_pem(is_private=False)) - - key2 = ECKey.generate_key('P-256', is_private=False) - self.assertRaises(ValueError, key2.as_pem, True) - self.assertIn(b'PUBLIC', key2.as_pem(is_private=False)) - - def test_okp_key_generate_pem(self): - self.assertRaises(ValueError, OKPKey.generate_key, 'invalid') - - key1 = OKPKey.generate_key('Ed25519', is_private=True) - self.assertIn(b'PRIVATE', key1.as_pem(is_private=True)) - self.assertIn(b'PUBLIC', key1.as_pem(is_private=False)) - - key2 = OKPKey.generate_key('X25519', is_private=False) - self.assertRaises(ValueError, key2.as_pem, True) - self.assertIn(b'PUBLIC', key2.as_pem(is_private=False)) diff --git a/tests/core/test_jose/test_jws.py b/tests/core/test_jose/test_jws.py deleted file mode 100644 index 026f86732..000000000 --- a/tests/core/test_jose/test_jws.py +++ /dev/null @@ -1,195 +0,0 @@ -import unittest -import json -from authlib.jose import JsonWebSignature -from authlib.jose import errors -from tests.util import read_file_path - - -class JWSTest(unittest.TestCase): - def test_invalid_input(self): - jws = JsonWebSignature() - self.assertRaises(errors.DecodeError, jws.deserialize, 'a', 'k') - self.assertRaises(errors.DecodeError, jws.deserialize, 'a.b.c', 'k') - self.assertRaises( - errors.DecodeError, jws.deserialize, 'YQ.YQ.YQ', 'k') # a - self.assertRaises( - errors.DecodeError, jws.deserialize, 'W10.a.YQ', 'k') # [] - self.assertRaises( - errors.DecodeError, jws.deserialize, 'e30.a.YQ', 'k') # {} - self.assertRaises( - errors.DecodeError, jws.deserialize, 'eyJhbGciOiJzIn0.a.YQ', 'k') - self.assertRaises( - errors.DecodeError, jws.deserialize, 'eyJhbGciOiJzIn0.YQ.a', 'k') - - def test_invalid_alg(self): - jws = JsonWebSignature() - self.assertRaises( - errors.UnsupportedAlgorithmError, - jws.deserialize, 'eyJhbGciOiJzIn0.YQ.YQ', 'k') - self.assertRaises( - errors.MissingAlgorithmError, - jws.serialize, {}, '', 'k' - ) - self.assertRaises( - errors.UnsupportedAlgorithmError, - jws.serialize, {'alg': 's'}, '', 'k' - ) - - def test_bad_signature(self): - jws = JsonWebSignature() - s = 'eyJhbGciOiJIUzI1NiJ9.YQ.YQ' - self.assertRaises(errors.BadSignatureError, jws.deserialize, s, 'k') - - def test_not_supported_alg(self): - jws = JsonWebSignature(algorithms=['HS256']) - s = jws.serialize({'alg': 'HS256'}, 'hello', 'secret') - - jws = JsonWebSignature(algorithms=['RS256']) - self.assertRaises( - errors.UnsupportedAlgorithmError, - lambda: jws.serialize({'alg': 'HS256'}, 'hello', 'secret') - ) - - self.assertRaises( - errors.UnsupportedAlgorithmError, - jws.deserialize, - s, 'secret' - ) - - def test_compact_jws(self): - jws = JsonWebSignature(algorithms=['HS256']) - s = jws.serialize({'alg': 'HS256'}, 'hello', 'secret') - data = jws.deserialize(s, 'secret') - header, payload = data['header'], data['payload'] - self.assertEqual(payload, b'hello') - self.assertEqual(header['alg'], 'HS256') - self.assertNotIn('signature', data) - - def test_compact_rsa(self): - jws = JsonWebSignature() - private_key = read_file_path('rsa_private.pem') - public_key = read_file_path('rsa_public.pem') - s = jws.serialize({'alg': 'RS256'}, 'hello', private_key) - data = jws.deserialize(s, public_key) - header, payload = data['header'], data['payload'] - self.assertEqual(payload, b'hello') - self.assertEqual(header['alg'], 'RS256') - - # can deserialize with private key - data2 = jws.deserialize(s, private_key) - self.assertEqual(data, data2) - - ssh_pub_key = read_file_path('ssh_public.pem') - self.assertRaises(errors.BadSignatureError, jws.deserialize, s, ssh_pub_key) - - def test_compact_rsa_pss(self): - jws = JsonWebSignature() - private_key = read_file_path('rsa_private.pem') - public_key = read_file_path('rsa_public.pem') - s = jws.serialize({'alg': 'PS256'}, 'hello', private_key) - data = jws.deserialize(s, public_key) - header, payload = data['header'], data['payload'] - self.assertEqual(payload, b'hello') - self.assertEqual(header['alg'], 'PS256') - ssh_pub_key = read_file_path('ssh_public.pem') - self.assertRaises(errors.BadSignatureError, jws.deserialize, s, ssh_pub_key) - - def test_compact_none(self): - jws = JsonWebSignature() - s = jws.serialize({'alg': 'none'}, 'hello', '') - self.assertRaises(errors.BadSignatureError, jws.deserialize, s, '') - - def test_flattened_json_jws(self): - jws = JsonWebSignature() - protected = {'alg': 'HS256'} - header = {'protected': protected, 'header': {'kid': 'a'}} - s = jws.serialize(header, 'hello', 'secret') - self.assertIsInstance(s, dict) - - data = jws.deserialize(s, 'secret') - header, payload = data['header'], data['payload'] - self.assertEqual(payload, b'hello') - self.assertEqual(header['alg'], 'HS256') - self.assertNotIn('protected', data) - - def test_nested_json_jws(self): - jws = JsonWebSignature() - protected = {'alg': 'HS256'} - header = {'protected': protected, 'header': {'kid': 'a'}} - s = jws.serialize([header], 'hello', 'secret') - self.assertIsInstance(s, dict) - self.assertIn('signatures', s) - - data = jws.deserialize(s, 'secret') - header, payload = data['header'], data['payload'] - self.assertEqual(payload, b'hello') - self.assertEqual(header[0]['alg'], 'HS256') - self.assertNotIn('signatures', data) - - # test bad signature - self.assertRaises(errors.BadSignatureError, jws.deserialize, s, 'f') - - def test_function_key(self): - protected = {'alg': 'HS256'} - header = [ - {'protected': protected, 'header': {'kid': 'a'}}, - {'protected': protected, 'header': {'kid': 'b'}}, - ] - - def load_key(header, payload): - self.assertEqual(payload, b'hello') - kid = header.get('kid') - if kid == 'a': - return 'secret-a' - return 'secret-b' - - jws = JsonWebSignature() - s = jws.serialize(header, b'hello', load_key) - self.assertIsInstance(s, dict) - self.assertIn('signatures', s) - - data = jws.deserialize(json.dumps(s), load_key) - header, payload = data['header'], data['payload'] - self.assertEqual(payload, b'hello') - self.assertEqual(header[0]['alg'], 'HS256') - self.assertNotIn('signature', data) - - def test_fail_deserialize_json(self): - jws = JsonWebSignature() - self.assertRaises(errors.DecodeError, jws.deserialize_json, None, '') - self.assertRaises(errors.DecodeError, jws.deserialize_json, '[]', '') - self.assertRaises(errors.DecodeError, jws.deserialize_json, '{}', '') - - # missing protected - s = json.dumps({'payload': 'YQ'}) - self.assertRaises(errors.DecodeError, jws.deserialize_json, s, '') - - # missing signature - s = json.dumps({'payload': 'YQ', 'protected': 'YQ'}) - self.assertRaises(errors.DecodeError, jws.deserialize_json, s, '') - - def test_validate_header(self): - jws = JsonWebSignature(private_headers=[]) - protected = {'alg': 'HS256', 'invalid': 'k'} - header = {'protected': protected, 'header': {'kid': 'a'}} - self.assertRaises( - errors.InvalidHeaderParameterName, - jws.serialize, header, b'hello', 'secret' - ) - jws = JsonWebSignature(private_headers=['invalid']) - s = jws.serialize(header, b'hello', 'secret') - self.assertIsInstance(s, dict) - - jws = JsonWebSignature() - s = jws.serialize(header, b'hello', 'secret') - self.assertIsInstance(s, dict) - - def test_EdDSA_alg(self): - jws = JsonWebSignature(algorithms=['EdDSA']) - private_key = read_file_path('ed25519-pkcs8.pem') - public_key = read_file_path('ed25519-pub.pem') - s = jws.serialize({'alg': 'EdDSA'}, 'hello', private_key) - data = jws.deserialize(s, public_key) - header, payload = data['header'], data['payload'] - self.assertEqual(payload, b'hello') - self.assertEqual(header['alg'], 'EdDSA') diff --git a/tests/core/test_jose/test_jwt.py b/tests/core/test_jose/test_jwt.py deleted file mode 100644 index 106149ea6..000000000 --- a/tests/core/test_jose/test_jwt.py +++ /dev/null @@ -1,188 +0,0 @@ -import unittest -import datetime -from authlib.jose import errors -from authlib.jose import JsonWebToken, JWTClaims, jwt -from authlib.jose.errors import UnsupportedAlgorithmError, InvalidUseError -from tests.util import read_file_path - - -class JWTTest(unittest.TestCase): - def test_init_algorithms(self): - _jwt = JsonWebToken(['RS256']) - self.assertRaises( - UnsupportedAlgorithmError, - _jwt.encode, {'alg': 'HS256'}, {}, 'k' - ) - - _jwt = JsonWebToken('RS256') - self.assertRaises( - UnsupportedAlgorithmError, - _jwt.encode, {'alg': 'HS256'}, {}, 'k' - ) - - def test_encode_sensitive_data(self): - # check=False won't raise error - jwt.encode({'alg': 'HS256'}, {'password': ''}, 'k', check=False) - self.assertRaises( - errors.InsecureClaimError, - jwt.encode, {'alg': 'HS256'}, {'password': ''}, 'k' - ) - self.assertRaises( - errors.InsecureClaimError, - jwt.encode, {'alg': 'HS256'}, {'text': '4242424242424242'}, 'k' - ) - - def test_encode_datetime(self): - now = datetime.datetime.utcnow() - id_token = jwt.encode({'alg': 'HS256'}, {'exp': now}, 'k') - claims = jwt.decode(id_token, 'k') - self.assertIsInstance(claims.exp, int) - - def test_validate_essential_claims(self): - id_token = jwt.encode({'alg': 'HS256'}, {'iss': 'foo'}, 'k') - claims_options = { - 'iss': { - 'essential': True, - 'values': ['foo'] - } - } - claims = jwt.decode(id_token, 'k', claims_options=claims_options) - claims.validate() - - claims.options = {'sub': {'essential': True}} - self.assertRaises( - errors.MissingClaimError, - claims.validate - ) - - def test_attribute_error(self): - claims = JWTClaims({'iss': 'foo'}, {'alg': 'HS256'}) - self.assertRaises(AttributeError, lambda: claims.invalid) - - def test_invalid_values(self): - id_token = jwt.encode({'alg': 'HS256'}, {'iss': 'foo'}, 'k') - claims_options = {'iss': {'values': ['bar']}} - claims = jwt.decode(id_token, 'k', claims_options=claims_options) - self.assertRaises( - errors.InvalidClaimError, - claims.validate, - ) - claims.options = {'iss': {'value': 'bar'}} - self.assertRaises( - errors.InvalidClaimError, - claims.validate, - ) - - def test_validate_aud(self): - id_token = jwt.encode({'alg': 'HS256'}, {'aud': 'foo'}, 'k') - claims_options = { - 'aud': { - 'essential': True, - 'value': 'foo' - } - } - claims = jwt.decode(id_token, 'k', claims_options=claims_options) - claims.validate() - - claims.options = { - 'aud': {'values': ['bar']} - } - self.assertRaises( - errors.InvalidClaimError, - claims.validate - ) - - id_token = jwt.encode({'alg': 'HS256'}, {'aud': ['foo', 'bar']}, 'k') - claims = jwt.decode(id_token, 'k', claims_options=claims_options) - claims.validate() - # no validate - claims.options = {'aud': {'values': []}} - claims.validate() - - def test_validate_exp(self): - id_token = jwt.encode({'alg': 'HS256'}, {'exp': 'invalid'}, 'k') - claims = jwt.decode(id_token, 'k') - self.assertRaises( - errors.InvalidClaimError, - claims.validate - ) - - id_token = jwt.encode({'alg': 'HS256'}, {'exp': 1234}, 'k') - claims = jwt.decode(id_token, 'k') - self.assertRaises( - errors.ExpiredTokenError, - claims.validate - ) - - def test_validate_nbf(self): - id_token = jwt.encode({'alg': 'HS256'}, {'nbf': 'invalid'}, 'k') - claims = jwt.decode(id_token, 'k') - self.assertRaises( - errors.InvalidClaimError, - claims.validate - ) - - id_token = jwt.encode({'alg': 'HS256'}, {'nbf': 1234}, 'k') - claims = jwt.decode(id_token, 'k') - claims.validate() - - id_token = jwt.encode({'alg': 'HS256'}, {'nbf': 1234}, 'k') - claims = jwt.decode(id_token, 'k') - self.assertRaises( - errors.InvalidTokenError, - claims.validate, 123 - ) - - def test_validate_iat(self): - id_token = jwt.encode({'alg': 'HS256'}, {'iat': 'invalid'}, 'k') - claims = jwt.decode(id_token, 'k') - self.assertRaises( - errors.InvalidClaimError, - claims.validate - ) - - def test_validate_jti(self): - id_token = jwt.encode({'alg': 'HS256'}, {'jti': 'bar'}, 'k') - claims_options = { - 'jti': { - 'validate': lambda c, o: o == 'foo' - } - } - claims = jwt.decode(id_token, 'k', claims_options=claims_options) - self.assertRaises( - errors.InvalidClaimError, - claims.validate - ) - - def test_use_jws(self): - payload = {'name': 'hi'} - private_key = read_file_path('rsa_private.pem') - pub_key = read_file_path('rsa_public.pem') - data = jwt.encode({'alg': 'RS256'}, payload, private_key) - self.assertEqual(data.count(b'.'), 2) - - claims = jwt.decode(data, pub_key) - self.assertEqual(claims['name'], 'hi') - - def test_use_jwe(self): - payload = {'name': 'hi'} - private_key = read_file_path('rsa_private.pem') - pub_key = read_file_path('rsa_public.pem') - data = jwt.encode( - {'alg': 'RSA-OAEP', 'enc': 'A256GCM'}, - payload, pub_key - ) - self.assertEqual(data.count(b'.'), 4) - - claims = jwt.decode(data, private_key) - self.assertEqual(claims['name'], 'hi') - - def test_with_ec(self): - payload = {'name': 'hi'} - private_key = read_file_path('ec_private.json') - pub_key = read_file_path('ec_public.json') - data = jwt.encode({'alg': 'ES256'}, payload, private_key) - self.assertEqual(data.count(b'.'), 2) - - claims = jwt.decode(data, pub_key) - self.assertEqual(claims['name'], 'hi') diff --git a/tests/core/test_legacy.py b/tests/core/test_legacy.py new file mode 100644 index 000000000..28b2e9646 --- /dev/null +++ b/tests/core/test_legacy.py @@ -0,0 +1,30 @@ +from joserfc.jwk import KeySet +from joserfc.jwk import OctKey + +from authlib._joserfc_helpers import import_any_key +from authlib.jose import OctKey as AuthlibOctKey + + +def test_import_legacy_oct_key(): + key1 = AuthlibOctKey.generate_key() + key2 = import_any_key(key1) + assert isinstance(key2, OctKey) + + +def test_import_from_json_str(): + data = '{"kty":"oct","k":"mGF6N2AY9YSRizMBv-DMe5NGpIP7AAcGX_w_jdiHMWc"}' + key = import_any_key(data) + assert isinstance(key, OctKey) + + +def test_import_raw_str(): + key = import_any_key("foo") + assert isinstance(key, OctKey) + + +def test_import_key_set(): + data = { + "keys": [{"kty": "oct", "k": "mGF6N2AY9YSRizMBv-DMe5NGpIP7AAcGX_w_jdiHMWc"}] + } + key = import_any_key(data) + assert isinstance(key, KeySet) diff --git a/tests/core/test_oauth2/test_rfc6749_misc.py b/tests/core/test_oauth2/test_rfc6749_misc.py index 612353bdc..819dc300a 100644 --- a/tests/core/test_oauth2/test_rfc6749_misc.py +++ b/tests/core/test_oauth2/test_rfc6749_misc.py @@ -1,89 +1,107 @@ -import unittest import base64 +import time + +import pytest + +from authlib.oauth2.rfc6749 import OAuth2Token +from authlib.oauth2.rfc6749 import errors from authlib.oauth2.rfc6749 import parameters from authlib.oauth2.rfc6749 import util -from authlib.oauth2.rfc6749 import errors -class OAuth2ParametersTest(unittest.TestCase): - def test_parse_authorization_code_response(self): - self.assertRaises( - errors.MissingCodeException, - parameters.parse_authorization_code_response, - 'https://i.b/?state=c' +def test_parse_authorization_code_response(): + with pytest.raises(errors.MissingCodeException): + parameters.parse_authorization_code_response( + "https://provider.test/?state=c", ) - self.assertRaises( - errors.MismatchingStateException, - parameters.parse_authorization_code_response, - 'https://i.b/?code=a&state=c', - 'b' + with pytest.raises(errors.MismatchingStateException): + parameters.parse_authorization_code_response( + "https://provider.test/?code=a&state=c", + "b", ) - url = 'https://i.b/?code=a&state=c' - rv = parameters.parse_authorization_code_response(url, 'c') - self.assertEqual(rv, {'code': 'a', 'state': 'c'}) + url = "https://provider.test/?code=a&state=c" + rv = parameters.parse_authorization_code_response(url, "c") + assert rv == {"code": "a", "state": "c"} - def test_parse_implicit_response(self): - self.assertRaises( - errors.MissingTokenException, - parameters.parse_implicit_response, - 'https://i.b/#a=b' - ) - self.assertRaises( - errors.MissingTokenTypeException, - parameters.parse_implicit_response, - 'https://i.b/#access_token=a' +def test_parse_implicit_response(): + with pytest.raises(errors.MissingTokenException): + parameters.parse_implicit_response( + "https://provider.test/#a=b", ) - self.assertRaises( - errors.MismatchingStateException, - parameters.parse_implicit_response, - 'https://i.b/#access_token=a&token_type=bearer&state=c', - 'abc' + with pytest.raises(errors.MissingTokenTypeException): + parameters.parse_implicit_response( + "https://provider.test/#access_token=a", ) - url = 'https://i.b/#access_token=a&token_type=bearer&state=c' - rv = parameters.parse_implicit_response(url, 'c') - self.assertEqual( - rv, - {'access_token': 'a', 'token_type': 'bearer', 'state': 'c'} + with pytest.raises(errors.MismatchingStateException): + parameters.parse_implicit_response( + "https://provider.test/#access_token=a&token_type=bearer&state=c", + "abc", ) + url = "https://provider.test/#access_token=a&token_type=bearer&state=c" + rv = parameters.parse_implicit_response(url, "c") + assert rv == {"access_token": "a", "token_type": "bearer", "state": "c"} -class OAuth2UtilTest(unittest.TestCase): - def test_list_to_scope(self): - self.assertEqual(util.list_to_scope(['a', 'b']), 'a b') - self.assertEqual(util.list_to_scope('a b'), 'a b') - self.assertIsNone(util.list_to_scope(None)) - def test_scope_to_list(self): - self.assertEqual(util.scope_to_list('a b'), ['a', 'b']) - self.assertEqual(util.scope_to_list(['a', 'b']), ['a', 'b']) - self.assertIsNone(util.scope_to_list(None)) +def test_prepare_grant_uri(): + grant_uri = parameters.prepare_grant_uri( + "https://provider.test/authorize", "dev", "code", max_age=0, resource=["a", "b"] + ) + assert ( + grant_uri + == "https://provider.test/authorize?response_type=code&client_id=dev&max_age=0&resource=a&resource=b" + ) - def test_extract_basic_authorization(self): - self.assertEqual(util.extract_basic_authorization({}), (None, None)) - self.assertEqual( - util.extract_basic_authorization({'Authorization': 'invalid'}), - (None, None) - ) - text = 'Basic invalid-base64' - self.assertEqual( - util.extract_basic_authorization({'Authorization': text}), - (None, None) - ) +def test_list_to_scope(): + assert util.list_to_scope(["a", "b"]) == "a b" + assert util.list_to_scope("a b") == "a b" + assert util.list_to_scope(None) is None - text = 'Basic {}'.format(base64.b64encode(b'a').decode()) - self.assertEqual( - util.extract_basic_authorization({'Authorization': text}), - ('a', None) - ) - text = 'Basic {}'.format(base64.b64encode(b'a:b').decode()) - self.assertEqual( - util.extract_basic_authorization({'Authorization': text}), - ('a', 'b') - ) +def test_scope_to_list(): + assert util.scope_to_list("a b") == ["a", "b"] + assert util.scope_to_list(["a", "b"]) == ["a", "b"] + assert util.scope_to_list(None) is None + + +def test_extract_basic_authorization(): + assert util.extract_basic_authorization({}) == (None, None) + assert util.extract_basic_authorization({"Authorization": "invalid"}) == ( + None, + None, + ) + + text = "Basic invalid-base64" + assert util.extract_basic_authorization({"Authorization": text}) == (None, None) + + text = "Basic {}".format(base64.b64encode(b"a").decode()) + assert util.extract_basic_authorization({"Authorization": text}) == ("a", None) + + text = "Basic {}".format(base64.b64encode(b"a:b").decode()) + assert util.extract_basic_authorization({"Authorization": text}) == ("a", "b") + + +def test_oauth2token_is_expired_with_expires_at_zero(): + """Token with expires_at=0 (epoch) should be considered expired.""" + token = OAuth2Token({"access_token": "a", "expires_at": 0}) + assert token["expires_at"] == 0 + assert token.is_expired() is True + + +def test_oauth2token_is_expired_with_expires_at_none(): + """Token with no expires_at should return None for is_expired.""" + token = OAuth2Token({"access_token": "a"}) + assert token.is_expired() is None + + +def test_oauth2token_is_expired_with_valid_token(): + """Token with future expires_at should not be expired.""" + future = int(time.time()) + 7200 + token = OAuth2Token({"access_token": "a", "expires_at": future}) + assert token.is_expired() is False diff --git a/tests/core/test_oauth2/test_rfc6750.py b/tests/core/test_oauth2/test_rfc6750.py new file mode 100644 index 000000000..4270dc76e --- /dev/null +++ b/tests/core/test_oauth2/test_rfc6750.py @@ -0,0 +1,10 @@ +from authlib.oauth2.rfc6750.errors import InvalidTokenError + + +def test_invalid_token_error_extra_attributes_in_www_authenticate(): + """Extra attributes passed to InvalidTokenError should appear as + individual key=value pairs in the WWW-Authenticate header.""" + error = InvalidTokenError(extra_attributes={"foo": "bar"}) + headers = dict(error.get_headers()) + www_authenticate = headers["WWW-Authenticate"] + assert 'foo="bar"' in www_authenticate diff --git a/tests/core/test_oauth2/test_rfc7523_client_secret.py b/tests/core/test_oauth2/test_rfc7523_client_secret.py new file mode 100644 index 000000000..b8c7e3420 --- /dev/null +++ b/tests/core/test_oauth2/test_rfc7523_client_secret.py @@ -0,0 +1,234 @@ +import time +from unittest import mock + +from joserfc import jwt +from joserfc.jwk import OctKey + +from authlib.oauth2.rfc7523 import ClientSecretJWT + + +def test_nothing_set(): + jwt_signer = ClientSecretJWT() + + assert jwt_signer.token_endpoint is None + assert jwt_signer.claims is None + assert jwt_signer.headers is None + assert jwt_signer.alg == "HS256" + + +def test_endpoint_set(): + jwt_signer = ClientSecretJWT( + token_endpoint="https://provider.test/oauth/access_token" + ) + + assert jwt_signer.token_endpoint == "https://provider.test/oauth/access_token" + assert jwt_signer.claims is None + assert jwt_signer.headers is None + assert jwt_signer.alg == "HS256" + + +def test_alg_set(): + jwt_signer = ClientSecretJWT(alg="HS512") + + assert jwt_signer.token_endpoint is None + assert jwt_signer.claims is None + assert jwt_signer.headers is None + assert jwt_signer.alg == "HS512" + + +def test_claims_set(): + jwt_signer = ClientSecretJWT(claims={"foo1": "bar1"}) + + assert jwt_signer.token_endpoint is None + assert jwt_signer.claims == {"foo1": "bar1"} + assert jwt_signer.headers is None + assert jwt_signer.alg == "HS256" + + +def test_headers_set(): + jwt_signer = ClientSecretJWT(headers={"foo1": "bar1"}) + + assert jwt_signer.token_endpoint is None + assert jwt_signer.claims is None + assert jwt_signer.headers == {"foo1": "bar1"} + assert jwt_signer.alg == "HS256" + + +def test_all_set(): + jwt_signer = ClientSecretJWT( + token_endpoint="https://provider.test/oauth/access_token", + claims={"foo1a": "bar1a"}, + headers={"foo1b": "bar1b"}, + alg="HS512", + ) + + assert jwt_signer.token_endpoint == "https://provider.test/oauth/access_token" + assert jwt_signer.claims == {"foo1a": "bar1a"} + assert jwt_signer.headers == {"foo1b": "bar1b"} + assert jwt_signer.alg == "HS512" + + +def sign_and_decode(jwt_signer, client_id, client_secret, token_endpoint): + auth = mock.MagicMock() + auth.client_id = client_id + auth.client_secret = client_secret + + pre_sign_time = int(time.time()) + + data = jwt_signer.sign(auth, token_endpoint) + decoded = jwt.decode(data, OctKey.import_key(client_secret)) + + iat = decoded.claims.pop("iat") + exp = decoded.claims.pop("exp") + jti = decoded.claims.pop("jti") + + return decoded, pre_sign_time, iat, exp, jti + + +def test_sign_nothing_set(): + jwt_signer = ClientSecretJWT() + + decoded, pre_sign_time, iat, exp, jti = sign_and_decode( + jwt_signer, + "client_id_1", + "client_secret_1", + "https://provider.test/oauth/access_token", + ) + + assert iat >= pre_sign_time + assert exp >= iat + 3600 + assert exp <= iat + 3600 + 2 + assert jti is not None + + assert { + "iss": "client_id_1", + "aud": "https://provider.test/oauth/access_token", + "sub": "client_id_1", + } == decoded.claims + + assert {"alg": "HS256", "typ": "JWT"} == decoded.header + + +def test_sign_custom_jti(): + jwt_signer = ClientSecretJWT(claims={"jti": "custom_jti"}) + + decoded, pre_sign_time, iat, exp, jti = sign_and_decode( + jwt_signer, + "client_id_1", + "client_secret_1", + "https://provider.test/oauth/access_token", + ) + + assert iat >= pre_sign_time + assert exp >= iat + 3600 + assert exp <= iat + 3600 + 2 + assert "custom_jti" == jti + + assert decoded.claims == { + "iss": "client_id_1", + "aud": "https://provider.test/oauth/access_token", + "sub": "client_id_1", + } + assert {"alg": "HS256", "typ": "JWT"} == decoded.header + + +def test_sign_with_additional_header(): + jwt_signer = ClientSecretJWT(headers={"kid": "custom_kid"}) + + decoded, pre_sign_time, iat, exp, jti = sign_and_decode( + jwt_signer, + "client_id_1", + "client_secret_1", + "https://provider.test/oauth/access_token", + ) + + assert iat >= pre_sign_time + assert exp >= iat + 3600 + assert exp <= iat + 3600 + 2 + assert jti is not None + + assert decoded.claims == { + "iss": "client_id_1", + "aud": "https://provider.test/oauth/access_token", + "sub": "client_id_1", + } + assert {"alg": "HS256", "typ": "JWT", "kid": "custom_kid"} == decoded.header + + +def test_sign_with_additional_headers(): + jwt_signer = ClientSecretJWT( + headers={"kid": "custom_kid", "jku": "https://provider.test/oauth/jwks"} + ) + + decoded, pre_sign_time, iat, exp, jti = sign_and_decode( + jwt_signer, + "client_id_1", + "client_secret_1", + "https://provider.test/oauth/access_token", + ) + + assert iat >= pre_sign_time + assert exp >= iat + 3600 + assert exp <= iat + 3600 + 2 + assert jti is not None + + assert decoded.claims == { + "iss": "client_id_1", + "aud": "https://provider.test/oauth/access_token", + "sub": "client_id_1", + } + assert { + "alg": "HS256", + "typ": "JWT", + "kid": "custom_kid", + "jku": "https://provider.test/oauth/jwks", + } == decoded.header + + +def test_sign_with_additional_claim(): + jwt_signer = ClientSecretJWT(claims={"name": "Foo"}) + + decoded, pre_sign_time, iat, exp, jti = sign_and_decode( + jwt_signer, + "client_id_1", + "client_secret_1", + "https://provider.test/oauth/access_token", + ) + + assert iat >= pre_sign_time + assert exp >= iat + 3600 + assert exp <= iat + 3600 + 2 + assert jti is not None + + assert decoded.claims == { + "iss": "client_id_1", + "aud": "https://provider.test/oauth/access_token", + "sub": "client_id_1", + "name": "Foo", + } + assert {"alg": "HS256", "typ": "JWT"} == decoded.header + + +def test_sign_with_additional_claims(): + jwt_signer = ClientSecretJWT(claims={"name": "Foo", "role": "bar"}) + + decoded, pre_sign_time, iat, exp, jti = sign_and_decode( + jwt_signer, + "client_id_1", + "client_secret_1", + "https://provider.test/oauth/access_token", + ) + + assert iat >= pre_sign_time + assert exp >= iat + 3600 + assert exp <= iat + 3600 + 2 + assert jti is not None + + assert decoded.claims == { + "iss": "client_id_1", + "aud": "https://provider.test/oauth/access_token", + "sub": "client_id_1", + "name": "Foo", + "role": "bar", + } + assert {"alg": "HS256", "typ": "JWT"} == decoded.header diff --git a/tests/core/test_oauth2/test_rfc7523_private_key.py b/tests/core/test_oauth2/test_rfc7523_private_key.py new file mode 100644 index 000000000..700d60231 --- /dev/null +++ b/tests/core/test_oauth2/test_rfc7523_private_key.py @@ -0,0 +1,231 @@ +import time +from unittest import mock + +from joserfc import jwt +from joserfc.jwk import RSAKey + +from authlib.oauth2.rfc7523 import PrivateKeyJWT +from tests.util import read_file_path + +public_key = read_file_path("rsa_public.pem") +private_key = read_file_path("rsa_private.pem") + + +def test_nothing_set(): + jwt_signer = PrivateKeyJWT() + + assert jwt_signer.token_endpoint is None + assert jwt_signer.claims is None + assert jwt_signer.headers is None + assert jwt_signer.alg == "RS256" + + +def test_endpoint_set(): + jwt_signer = PrivateKeyJWT( + token_endpoint="https://provider.test/oauth/access_token" + ) + + assert jwt_signer.token_endpoint == "https://provider.test/oauth/access_token" + assert jwt_signer.claims is None + assert jwt_signer.headers is None + assert jwt_signer.alg == "RS256" + + +def test_alg_set(): + jwt_signer = PrivateKeyJWT(alg="RS512") + + assert jwt_signer.token_endpoint is None + assert jwt_signer.claims is None + assert jwt_signer.headers is None + assert jwt_signer.alg == "RS512" + + +def test_claims_set(): + jwt_signer = PrivateKeyJWT(claims={"foo1": "bar1"}) + + assert jwt_signer.token_endpoint is None + assert jwt_signer.claims == {"foo1": "bar1"} + assert jwt_signer.headers is None + assert jwt_signer.alg == "RS256" + + +def test_headers_set(): + jwt_signer = PrivateKeyJWT(headers={"foo1": "bar1"}) + + assert jwt_signer.token_endpoint is None + assert jwt_signer.claims is None + assert jwt_signer.headers == {"foo1": "bar1"} + assert jwt_signer.alg == "RS256" + + +def test_all_set(): + jwt_signer = PrivateKeyJWT( + token_endpoint="https://provider.test/oauth/access_token", + claims={"foo1a": "bar1a"}, + headers={"foo1b": "bar1b"}, + alg="RS512", + ) + + assert jwt_signer.token_endpoint == "https://provider.test/oauth/access_token" + assert jwt_signer.claims == {"foo1a": "bar1a"} + assert jwt_signer.headers == {"foo1b": "bar1b"} + assert jwt_signer.alg == "RS512" + + +def sign_and_decode(jwt_signer, client_id, token_endpoint): + auth = mock.MagicMock() + auth.client_id = client_id + auth.client_secret = private_key + + pre_sign_time = int(time.time()) + + data = jwt_signer.sign(auth, token_endpoint) + decoded = jwt.decode(data, RSAKey.import_key(public_key)) + + iat = decoded.claims.pop("iat") + exp = decoded.claims.pop("exp") + jti = decoded.claims.pop("jti") + + return decoded, pre_sign_time, iat, exp, jti + + +def test_sign_nothing_set(): + jwt_signer = PrivateKeyJWT() + + decoded, pre_sign_time, iat, exp, jti = sign_and_decode( + jwt_signer, + "client_id_1", + "https://provider.test/oauth/access_token", + ) + + assert iat >= pre_sign_time + assert exp >= iat + 3600 + assert exp <= iat + 3600 + 2 + assert jti is not None + + assert { + "iss": "client_id_1", + "aud": "https://provider.test/oauth/access_token", + "sub": "client_id_1", + } == decoded.claims + assert {"alg": "RS256", "typ": "JWT"} == decoded.header + + +def test_sign_custom_jti(): + jwt_signer = PrivateKeyJWT(claims={"jti": "custom_jti"}) + + decoded, pre_sign_time, iat, exp, jti = sign_and_decode( + jwt_signer, + "client_id_1", + "https://provider.test/oauth/access_token", + ) + + assert iat >= pre_sign_time + assert exp >= iat + 3600 + assert exp <= iat + 3600 + 2 + assert "custom_jti" == jti + + assert decoded.claims == { + "iss": "client_id_1", + "aud": "https://provider.test/oauth/access_token", + "sub": "client_id_1", + } + assert {"alg": "RS256", "typ": "JWT"} == decoded.header + + +def test_sign_with_additional_header(): + jwt_signer = PrivateKeyJWT(headers={"kid": "custom_kid"}) + + decoded, pre_sign_time, iat, exp, jti = sign_and_decode( + jwt_signer, + "client_id_1", + "https://provider.test/oauth/access_token", + ) + + assert iat >= pre_sign_time + assert exp >= iat + 3600 + assert exp <= iat + 3600 + 2 + assert jti is not None + + assert decoded.claims == { + "iss": "client_id_1", + "aud": "https://provider.test/oauth/access_token", + "sub": "client_id_1", + } + assert {"alg": "RS256", "typ": "JWT", "kid": "custom_kid"} == decoded.header + + +def test_sign_with_additional_headers(): + jwt_signer = PrivateKeyJWT( + headers={"kid": "custom_kid", "jku": "https://provider.test/oauth/jwks"} + ) + + decoded, pre_sign_time, iat, exp, jti = sign_and_decode( + jwt_signer, + "client_id_1", + "https://provider.test/oauth/access_token", + ) + + assert iat >= pre_sign_time + assert exp >= iat + 3600 + assert exp <= iat + 3600 + 2 + assert jti is not None + + assert decoded.claims == { + "iss": "client_id_1", + "aud": "https://provider.test/oauth/access_token", + "sub": "client_id_1", + } + assert { + "alg": "RS256", + "typ": "JWT", + "kid": "custom_kid", + "jku": "https://provider.test/oauth/jwks", + } == decoded.header + + +def test_sign_with_additional_claim(): + jwt_signer = PrivateKeyJWT(claims={"name": "Foo"}) + + decoded, pre_sign_time, iat, exp, jti = sign_and_decode( + jwt_signer, + "client_id_1", + "https://provider.test/oauth/access_token", + ) + + assert iat >= pre_sign_time + assert exp >= iat + 3600 + assert exp <= iat + 3600 + 2 + assert jti is not None + + assert decoded.claims == { + "iss": "client_id_1", + "aud": "https://provider.test/oauth/access_token", + "sub": "client_id_1", + "name": "Foo", + } + assert {"alg": "RS256", "typ": "JWT"} == decoded.header + + +def test_sign_with_additional_claims(): + jwt_signer = PrivateKeyJWT(claims={"name": "Foo", "role": "bar"}) + + decoded, pre_sign_time, iat, exp, jti = sign_and_decode( + jwt_signer, + "client_id_1", + "https://provider.test/oauth/access_token", + ) + + assert iat >= pre_sign_time + assert exp >= iat + 3600 + assert exp <= iat + 3600 + 2 + assert jti is not None + + assert decoded.claims == { + "iss": "client_id_1", + "aud": "https://provider.test/oauth/access_token", + "sub": "client_id_1", + "name": "Foo", + "role": "bar", + } + assert {"alg": "RS256", "typ": "JWT"} == decoded.header diff --git a/tests/core/test_oauth2/test_rfc7523_validator.py b/tests/core/test_oauth2/test_rfc7523_validator.py new file mode 100644 index 000000000..fcf5314b4 --- /dev/null +++ b/tests/core/test_oauth2/test_rfc7523_validator.py @@ -0,0 +1,61 @@ +import time + +import pytest +from joserfc import jws +from joserfc import jwt +from joserfc.jwk import OctKey + +from authlib.oauth2.rfc6750.errors import InvalidTokenError +from authlib.oauth2.rfc7523 import JWTBearerTokenValidator + + +@pytest.fixture +def oct_key(): + return OctKey.generate_key() + + +def test_invalid_token_string(oct_key): + validator = JWTBearerTokenValidator(oct_key) + token_string = jws.serialize_compact({"alg": "HS256"}, "text", oct_key) + token = validator.authenticate_token(token_string) + assert token is None + + +def test_missint_claims(oct_key): + validator = JWTBearerTokenValidator(oct_key) + token_string = jwt.encode({"alg": "HS256"}, {}, oct_key) + token = validator.authenticate_token(token_string) + assert token is None + + +def test_authenticate_token(oct_key): + validator = JWTBearerTokenValidator(oct_key, issuer="foo") + claims = { + "iss": "bar", + "exp": int(time.time() + 3600), + "client_id": "client-id", + "grant_type": "client_credentials", + } + token_string = jwt.encode({"alg": "HS256"}, claims, oct_key) + token = validator.authenticate_token(token_string) + assert token is None + + token_string = jwt.encode({"alg": "HS256"}, {**claims, "iss": "foo"}, oct_key) + token = validator.authenticate_token(token_string) + assert token is not None + + +def test_expired_token(oct_key): + validator = JWTBearerTokenValidator(oct_key) + claims = { + "exp": time.time() + 0.01, + "client_id": "client-id", + "grant_type": "client_credentials", + } + token_string = jwt.encode({"alg": "HS256"}, claims, oct_key) + token = validator.authenticate_token(token_string) + assert token is not None + + time.sleep(0.1) + with pytest.raises(InvalidTokenError): + validator.validate_token(token, [], None) diff --git a/tests/core/test_oauth2/test_rfc7591.py b/tests/core/test_oauth2/test_rfc7591.py index 175a26858..f3c5bcf0a 100644 --- a/tests/core/test_oauth2/test_rfc7591.py +++ b/tests/core/test_oauth2/test_rfc7591.py @@ -1,29 +1,40 @@ -from unittest import TestCase +import pytest +from joserfc.errors import InvalidClaimError + from authlib.oauth2.rfc7591 import ClientMetadataClaims -from authlib.jose.errors import InvalidClaimError -class ClientMetadataClaimsTest(TestCase): - def test_validate_redirect_uris(self): - claims = ClientMetadataClaims({'redirect_uris': ['foo']}, {}) - self.assertRaises(InvalidClaimError, claims.validate) +def test_validate_redirect_uris(): + claims = ClientMetadataClaims({"redirect_uris": ["foo"]}, {}) + with pytest.raises(InvalidClaimError): + claims.validate() + + +def test_validate_client_uri(): + claims = ClientMetadataClaims({"client_uri": "foo"}, {}) + with pytest.raises(InvalidClaimError): + claims.validate() + + +def test_validate_logo_uri(): + claims = ClientMetadataClaims({"logo_uri": "foo"}, {}) + with pytest.raises(InvalidClaimError): + claims.validate() + - def test_validate_client_uri(self): - claims = ClientMetadataClaims({'client_uri': 'foo'}, {}) - self.assertRaises(InvalidClaimError, claims.validate) +def test_validate_tos_uri(): + claims = ClientMetadataClaims({"tos_uri": "foo"}, {}) + with pytest.raises(InvalidClaimError): + claims.validate() - def test_validate_logo_uri(self): - claims = ClientMetadataClaims({'logo_uri': 'foo'}, {}) - self.assertRaises(InvalidClaimError, claims.validate) - def test_validate_tos_uri(self): - claims = ClientMetadataClaims({'tos_uri': 'foo'}, {}) - self.assertRaises(InvalidClaimError, claims.validate) +def test_validate_policy_uri(): + claims = ClientMetadataClaims({"policy_uri": "foo"}, {}) + with pytest.raises(InvalidClaimError): + claims.validate() - def test_validate_policy_uri(self): - claims = ClientMetadataClaims({'policy_uri': 'foo'}, {}) - self.assertRaises(InvalidClaimError, claims.validate) - def test_validate_jwks_uri(self): - claims = ClientMetadataClaims({'jwks_uri': 'foo'}, {}) - self.assertRaises(InvalidClaimError, claims.validate) +def test_validate_jwks_uri(): + claims = ClientMetadataClaims({"jwks_uri": "foo"}, {}) + with pytest.raises(InvalidClaimError): + claims.validate() diff --git a/tests/core/test_oauth2/test_rfc7662.py b/tests/core/test_oauth2/test_rfc7662.py index 80211bb95..1b4fee33d 100644 --- a/tests/core/test_oauth2/test_rfc7662.py +++ b/tests/core/test_oauth2/test_rfc7662.py @@ -1,55 +1,61 @@ -import unittest +import pytest + from authlib.oauth2.rfc7662 import IntrospectionToken -class IntrospectionTokenTest(unittest.TestCase): - def test_client_id(self): - token = IntrospectionToken() - self.assertIsNone(token.client_id) - self.assertIsNone(token.get_client_id()) - - token = IntrospectionToken({'client_id': 'foo'}) - self.assertEqual(token.client_id, 'foo') - self.assertEqual(token.get_client_id(), 'foo') - - def test_scope(self): - token = IntrospectionToken() - self.assertIsNone(token.scope) - self.assertIsNone(token.get_scope()) - - token = IntrospectionToken({'scope': 'foo'}) - self.assertEqual(token.scope, 'foo') - self.assertEqual(token.get_scope(), 'foo') - - def test_expires_in(self): - token = IntrospectionToken() - self.assertEqual(token.get_expires_in(), 0) - - def test_expires_at(self): - token = IntrospectionToken() - self.assertIsNone(token.exp) - self.assertEqual(token.get_expires_at(), 0) - - token = IntrospectionToken({'exp': 3600}) - self.assertEqual(token.exp, 3600) - self.assertEqual(token.get_expires_at(), 3600) - - def test_all_attributes(self): - # https://tools.ietf.org/html/rfc7662#section-2.2 - token = IntrospectionToken() - self.assertIsNone(token.active) - self.assertIsNone(token.scope) - self.assertIsNone(token.client_id) - self.assertIsNone(token.username) - self.assertIsNone(token.token_type) - self.assertIsNone(token.exp) - self.assertIsNone(token.iat) - self.assertIsNone(token.nbf) - self.assertIsNone(token.sub) - self.assertIsNone(token.aud) - self.assertIsNone(token.iss) - self.assertIsNone(token.jti) - - def test_invalid_attr(self): - token = IntrospectionToken() - self.assertRaises(AttributeError, lambda: token.invalid) +def test_client_id(): + token = IntrospectionToken() + assert token.client_id is None + assert token.get_client_id() is None + + token = IntrospectionToken({"client_id": "foo"}) + assert token.client_id == "foo" + assert token.get_client_id() == "foo" + + +def test_scope(): + token = IntrospectionToken() + assert token.scope is None + assert token.get_scope() is None + + token = IntrospectionToken({"scope": "foo"}) + assert token.scope == "foo" + assert token.get_scope() == "foo" + + +def test_expires_in(): + token = IntrospectionToken() + assert token.get_expires_in() == 0 + + +def test_expires_at(): + token = IntrospectionToken() + assert token.exp is None + assert token.get_expires_at() == 0 + + token = IntrospectionToken({"exp": 3600}) + assert token.exp == 3600 + assert token.get_expires_at() == 3600 + + +def test_all_attributes(): + # https://tools.ietf.org/html/rfc7662#section-2.2 + token = IntrospectionToken() + assert token.active is None + assert token.scope is None + assert token.client_id is None + assert token.username is None + assert token.token_type is None + assert token.exp is None + assert token.iat is None + assert token.nbf is None + assert token.sub is None + assert token.aud is None + assert token.iss is None + assert token.jti is None + + +def test_invalid_attr(): + token = IntrospectionToken() + with pytest.raises(AttributeError): + token.invalid # noqa:B018 diff --git a/tests/core/test_oauth2/test_rfc8414.py b/tests/core/test_oauth2/test_rfc8414.py index 1b8b98844..d266f2dd0 100644 --- a/tests/core/test_oauth2/test_rfc8414.py +++ b/tests/core/test_oauth2/test_rfc8414.py @@ -1,498 +1,462 @@ -import unittest -from authlib.oauth2.rfc8414 import get_well_known_url +import pytest + +from authlib.oauth2 import rfc9101 from authlib.oauth2.rfc8414 import AuthorizationServerMetadata +from authlib.oauth2.rfc8414 import get_well_known_url + +WELL_KNOWN_URL = "/.well-known/oauth-authorization-server" + + +def test_well_know_no_suffix_issuer(): + assert get_well_known_url("https://provider.test") == WELL_KNOWN_URL + assert get_well_known_url("https://provider.test/") == WELL_KNOWN_URL + + +def test_well_know_with_suffix_issuer(): + assert ( + get_well_known_url("https://provider.test/issuer1") + == WELL_KNOWN_URL + "/issuer1" + ) + assert ( + get_well_known_url("https://provider.test/a/b/c") == WELL_KNOWN_URL + "/a/b/c" + ) + + +def test_well_know_with_external(): + assert ( + get_well_known_url("https://provider.test", external=True) + == "https://provider.test" + WELL_KNOWN_URL + ) + +def test_well_know_with_changed_suffix(): + url = get_well_known_url("https://provider.test", suffix="openid-configuration") + assert url == "/.well-known/openid-configuration" + url = get_well_known_url( + "https://provider.test", external=True, suffix="openid-configuration" + ) + assert url == "https://provider.test/.well-known/openid-configuration" -WELL_KNOWN_URL = '/.well-known/oauth-authorization-server' - - -class WellKnownTest(unittest.TestCase): - def test_no_suffix_issuer(self): - self.assertEqual( - get_well_known_url('https://authlib.org'), - WELL_KNOWN_URL - ) - self.assertEqual( - get_well_known_url('https://authlib.org/'), - WELL_KNOWN_URL - ) - - def test_with_suffix_issuer(self): - self.assertEqual( - get_well_known_url('https://authlib.org/issuer1'), - WELL_KNOWN_URL + '/issuer1' - ) - self.assertEqual( - get_well_known_url('https://authlib.org/a/b/c'), - WELL_KNOWN_URL + '/a/b/c' - ) - - def test_with_external(self): - self.assertEqual( - get_well_known_url('https://authlib.org', external=True), - 'https://authlib.org' + WELL_KNOWN_URL - ) - - def test_with_changed_suffix(self): - url = get_well_known_url( - 'https://authlib.org', - suffix='openid-configuration') - self.assertEqual(url, '/.well-known/openid-configuration') - url = get_well_known_url( - 'https://authlib.org', - external=True, - suffix='openid-configuration' - ) - self.assertEqual(url, 'https://authlib.org/.well-known/openid-configuration') - - -class AuthorizationServerMetadataTest(unittest.TestCase): - def test_validate_issuer(self): - #: missing - metadata = AuthorizationServerMetadata({}) - with self.assertRaises(ValueError) as cm: - metadata.validate() - self.assertEqual('"issuer" is required', str(cm.exception)) - - #: https - metadata = AuthorizationServerMetadata({ - 'issuer': 'http://authlib.org/' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_issuer() - self.assertIn('https', str(cm.exception)) - - #: query - metadata = AuthorizationServerMetadata({ - 'issuer': 'https://authlib.org/?a=b' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_issuer() - self.assertIn('query', str(cm.exception)) - - #: fragment - metadata = AuthorizationServerMetadata({ - 'issuer': 'https://authlib.org/#a=b' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_issuer() - self.assertIn('fragment', str(cm.exception)) - - metadata = AuthorizationServerMetadata({ - 'issuer': 'https://authlib.org/' - }) + +def test_validate_issuer(): + #: missing + metadata = AuthorizationServerMetadata({}) + with pytest.raises(ValueError, match='"issuer" is required'): + metadata.validate() + + #: https + metadata = AuthorizationServerMetadata({"issuer": "http://provider.test/"}) + with pytest.raises(ValueError, match="https"): + metadata.validate_issuer() + + #: query + metadata = AuthorizationServerMetadata({"issuer": "https://provider.test/?a=b"}) + with pytest.raises(ValueError, match="query"): + metadata.validate_issuer() + + #: fragment + metadata = AuthorizationServerMetadata({"issuer": "https://provider.test/#a=b"}) + with pytest.raises(ValueError, match="fragment"): metadata.validate_issuer() - def test_validate_authorization_endpoint(self): - # https - metadata = AuthorizationServerMetadata({ - 'authorization_endpoint': 'http://authlib.org/' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_authorization_endpoint() - self.assertIn('https', str(cm.exception)) - - # valid https - metadata = AuthorizationServerMetadata({ - 'authorization_endpoint': 'https://authlib.org/' - }) + metadata = AuthorizationServerMetadata({"issuer": "https://provider.test/"}) + metadata.validate_issuer() + + +def test_validate_authorization_endpoint(): + # https + metadata = AuthorizationServerMetadata( + {"authorization_endpoint": "http://provider.test/"} + ) + with pytest.raises(ValueError, match="https"): metadata.validate_authorization_endpoint() - # missing - metadata = AuthorizationServerMetadata() - with self.assertRaises(ValueError) as cm: - metadata.validate_authorization_endpoint() - self.assertIn('required', str(cm.exception)) + # valid https + metadata = AuthorizationServerMetadata( + {"authorization_endpoint": "https://provider.test/"} + ) + metadata.validate_authorization_endpoint() - # valid missing - metadata = AuthorizationServerMetadata({ - 'grant_types_supported': ['password'] - }) + # missing + metadata = AuthorizationServerMetadata() + with pytest.raises(ValueError, match="required"): metadata.validate_authorization_endpoint() - def test_validate_token_endpoint(self): - # implicit - metadata = AuthorizationServerMetadata({ - 'grant_types_supported': ['implicit'] - }) + # valid missing + metadata = AuthorizationServerMetadata({"grant_types_supported": ["password"]}) + metadata.validate_authorization_endpoint() + + +def test_validate_token_endpoint(): + # implicit + metadata = AuthorizationServerMetadata({"grant_types_supported": ["implicit"]}) + metadata.validate_token_endpoint() + + # missing + metadata = AuthorizationServerMetadata() + with pytest.raises(ValueError, match="required"): metadata.validate_token_endpoint() - # missing - metadata = AuthorizationServerMetadata() - with self.assertRaises(ValueError) as cm: - metadata.validate_token_endpoint() - self.assertIn('required', str(cm.exception)) - - # https - metadata = AuthorizationServerMetadata({ - 'token_endpoint': 'http://authlib.org/' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_token_endpoint() - self.assertIn('https', str(cm.exception)) - - # valid - metadata = AuthorizationServerMetadata({ - 'token_endpoint': 'https://authlib.org/' - }) + # https + metadata = AuthorizationServerMetadata({"token_endpoint": "http://provider.test/"}) + with pytest.raises(ValueError, match="https"): metadata.validate_token_endpoint() - def test_validate_jwks_uri(self): - # can missing - metadata = AuthorizationServerMetadata() - metadata.validate_jwks_uri() + # valid + metadata = AuthorizationServerMetadata({"token_endpoint": "https://provider.test/"}) + metadata.validate_token_endpoint() + - metadata = AuthorizationServerMetadata({ - 'jwks_uri': 'http://authlib.org/jwks.json' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_jwks_uri() - self.assertIn('https', str(cm.exception)) +def test_validate_jwks_uri(): + # can missing + metadata = AuthorizationServerMetadata() + metadata.validate_jwks_uri() - metadata = AuthorizationServerMetadata({ - 'jwks_uri': 'https://authlib.org/jwks.json' - }) + metadata = AuthorizationServerMetadata( + {"jwks_uri": "http://provider.test/jwks.json"} + ) + with pytest.raises(ValueError, match="https"): metadata.validate_jwks_uri() - def test_validate_registration_endpoint(self): - metadata = AuthorizationServerMetadata() - metadata.validate_registration_endpoint() + metadata = AuthorizationServerMetadata( + {"jwks_uri": "https://provider.test/jwks.json"} + ) + metadata.validate_jwks_uri() - metadata = AuthorizationServerMetadata({ - 'registration_endpoint': 'http://authlib.org/' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_registration_endpoint() - self.assertIn('https', str(cm.exception)) - metadata = AuthorizationServerMetadata({ - 'registration_endpoint': 'https://authlib.org/' - }) +def test_validate_registration_endpoint(): + metadata = AuthorizationServerMetadata() + metadata.validate_registration_endpoint() + + metadata = AuthorizationServerMetadata( + {"registration_endpoint": "http://provider.test/"} + ) + with pytest.raises(ValueError, match="https"): metadata.validate_registration_endpoint() - def test_validate_scopes_supported(self): - metadata = AuthorizationServerMetadata() - metadata.validate_scopes_supported() + metadata = AuthorizationServerMetadata( + {"registration_endpoint": "https://provider.test/"} + ) + metadata.validate_registration_endpoint() + - # not array - metadata = AuthorizationServerMetadata({ - 'scopes_supported': 'foo' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_scopes_supported() - self.assertIn('JSON array', str(cm.exception)) - - # valid - metadata = AuthorizationServerMetadata({ - 'scopes_supported': ['foo'] - }) +def test_validate_scopes_supported(): + metadata = AuthorizationServerMetadata() + metadata.validate_scopes_supported() + + # not array + metadata = AuthorizationServerMetadata({"scopes_supported": "foo"}) + with pytest.raises(ValueError, match="JSON array"): metadata.validate_scopes_supported() - def test_validate_response_types_supported(self): - # missing - metadata = AuthorizationServerMetadata() - with self.assertRaises(ValueError) as cm: - metadata.validate_response_types_supported() - self.assertIn('required', str(cm.exception)) - - # not array - metadata = AuthorizationServerMetadata({ - 'response_types_supported': 'code' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_response_types_supported() - self.assertIn('JSON array', str(cm.exception)) - - # valid - metadata = AuthorizationServerMetadata({ - 'response_types_supported': ['code'] - }) + # valid + metadata = AuthorizationServerMetadata({"scopes_supported": ["foo"]}) + metadata.validate_scopes_supported() + + +def test_validate_response_types_supported(): + # missing + metadata = AuthorizationServerMetadata() + with pytest.raises(ValueError, match="required"): metadata.validate_response_types_supported() - def test_validate_response_modes_supported(self): - metadata = AuthorizationServerMetadata() - metadata.validate_response_modes_supported() + # not array + metadata = AuthorizationServerMetadata({"response_types_supported": "code"}) + with pytest.raises(ValueError, match="JSON array"): + metadata.validate_response_types_supported() - # not array - metadata = AuthorizationServerMetadata({ - 'response_modes_supported': 'query' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_response_modes_supported() - self.assertIn('JSON array', str(cm.exception)) - - # valid - metadata = AuthorizationServerMetadata({ - 'response_modes_supported': ['query'] - }) + # valid + metadata = AuthorizationServerMetadata({"response_types_supported": ["code"]}) + metadata.validate_response_types_supported() + + +def test_validate_response_modes_supported(): + metadata = AuthorizationServerMetadata() + metadata.validate_response_modes_supported() + + # not array + metadata = AuthorizationServerMetadata({"response_modes_supported": "query"}) + with pytest.raises(ValueError, match="JSON array"): metadata.validate_response_modes_supported() - def test_validate_grant_types_supported(self): - metadata = AuthorizationServerMetadata() - metadata.validate_grant_types_supported() + # valid + metadata = AuthorizationServerMetadata({"response_modes_supported": ["query"]}) + metadata.validate_response_modes_supported() + - # not array - metadata = AuthorizationServerMetadata({ - 'grant_types_supported': 'password' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_grant_types_supported() - self.assertIn('JSON array', str(cm.exception)) - - # valid - metadata = AuthorizationServerMetadata({ - 'grant_types_supported': ['password'] - }) +def test_validate_grant_types_supported(): + metadata = AuthorizationServerMetadata() + metadata.validate_grant_types_supported() + + # not array + metadata = AuthorizationServerMetadata({"grant_types_supported": "password"}) + with pytest.raises(ValueError, match="JSON array"): metadata.validate_grant_types_supported() - def test_validate_token_endpoint_auth_methods_supported(self): - metadata = AuthorizationServerMetadata() - metadata.validate_token_endpoint_auth_methods_supported() + # valid + metadata = AuthorizationServerMetadata({"grant_types_supported": ["password"]}) + metadata.validate_grant_types_supported() + - # not array - metadata = AuthorizationServerMetadata({ - 'token_endpoint_auth_methods_supported': 'client_secret_basic' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_token_endpoint_auth_methods_supported() - self.assertIn('JSON array', str(cm.exception)) - - # valid - metadata = AuthorizationServerMetadata({ - 'token_endpoint_auth_methods_supported': ['client_secret_basic'] - }) +def test_validate_token_endpoint_auth_methods_supported(): + metadata = AuthorizationServerMetadata() + metadata.validate_token_endpoint_auth_methods_supported() + + # not array + metadata = AuthorizationServerMetadata( + {"token_endpoint_auth_methods_supported": "client_secret_basic"} + ) + with pytest.raises(ValueError, match="JSON array"): metadata.validate_token_endpoint_auth_methods_supported() - def test_validate_token_endpoint_auth_signing_alg_values_supported(self): - metadata = AuthorizationServerMetadata() + # valid + metadata = AuthorizationServerMetadata( + {"token_endpoint_auth_methods_supported": ["client_secret_basic"]} + ) + metadata.validate_token_endpoint_auth_methods_supported() + + +def test_validate_token_endpoint_auth_signing_alg_values_supported(): + metadata = AuthorizationServerMetadata() + metadata.validate_token_endpoint_auth_signing_alg_values_supported() + + metadata = AuthorizationServerMetadata( + {"token_endpoint_auth_methods_supported": ["client_secret_jwt"]} + ) + with pytest.raises(ValueError, match="required"): metadata.validate_token_endpoint_auth_signing_alg_values_supported() - metadata = AuthorizationServerMetadata({ - 'token_endpoint_auth_methods_supported': ['client_secret_jwt'] - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_token_endpoint_auth_signing_alg_values_supported() - self.assertIn('required', str(cm.exception)) - - metadata = AuthorizationServerMetadata({ - 'token_endpoint_auth_signing_alg_values_supported': 'RS256' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_token_endpoint_auth_signing_alg_values_supported() - self.assertIn('JSON array', str(cm.exception)) - - metadata = AuthorizationServerMetadata({ - 'token_endpoint_auth_methods_supported': ['client_secret_jwt'], - 'token_endpoint_auth_signing_alg_values_supported': ['RS256', 'none'] - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_token_endpoint_auth_signing_alg_values_supported() - self.assertIn('none', str(cm.exception)) - - def test_validate_service_documentation(self): - metadata = AuthorizationServerMetadata() - metadata.validate_service_documentation() + metadata = AuthorizationServerMetadata( + {"token_endpoint_auth_signing_alg_values_supported": "RS256"} + ) + with pytest.raises(ValueError, match="JSON array"): + metadata.validate_token_endpoint_auth_signing_alg_values_supported() + + metadata = AuthorizationServerMetadata( + { + "token_endpoint_auth_methods_supported": ["client_secret_jwt"], + "token_endpoint_auth_signing_alg_values_supported": ["RS256", "none"], + } + ) + with pytest.raises(ValueError, match="none"): + metadata.validate_token_endpoint_auth_signing_alg_values_supported() - metadata = AuthorizationServerMetadata({ - 'service_documentation': 'invalid' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_service_documentation() - self.assertIn('MUST be a URL', str(cm.exception)) - metadata = AuthorizationServerMetadata({ - 'service_documentation': 'https://authlib.org/' - }) +def test_validate_service_documentation(): + metadata = AuthorizationServerMetadata() + metadata.validate_service_documentation() + + metadata = AuthorizationServerMetadata({"service_documentation": "invalid"}) + with pytest.raises(ValueError, match="MUST be a URL"): metadata.validate_service_documentation() - def test_validate_ui_locales_supported(self): - metadata = AuthorizationServerMetadata() - metadata.validate_ui_locales_supported() + metadata = AuthorizationServerMetadata( + {"service_documentation": "https://provider.test/"} + ) + metadata.validate_service_documentation() + - # not array - metadata = AuthorizationServerMetadata({ - 'ui_locales_supported': 'en' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_ui_locales_supported() - self.assertIn('JSON array', str(cm.exception)) - - # valid - metadata = AuthorizationServerMetadata({ - 'ui_locales_supported': ['en'] - }) +def test_validate_ui_locales_supported(): + metadata = AuthorizationServerMetadata() + metadata.validate_ui_locales_supported() + + # not array + metadata = AuthorizationServerMetadata({"ui_locales_supported": "en"}) + with pytest.raises(ValueError, match="JSON array"): metadata.validate_ui_locales_supported() - def test_validate_op_policy_uri(self): - metadata = AuthorizationServerMetadata() - metadata.validate_op_policy_uri() + # valid + metadata = AuthorizationServerMetadata({"ui_locales_supported": ["en"]}) + metadata.validate_ui_locales_supported() - metadata = AuthorizationServerMetadata({ - 'op_policy_uri': 'invalid' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_op_policy_uri() - self.assertIn('MUST be a URL', str(cm.exception)) - metadata = AuthorizationServerMetadata({ - 'op_policy_uri': 'https://authlib.org/' - }) +def test_validate_op_policy_uri(): + metadata = AuthorizationServerMetadata() + metadata.validate_op_policy_uri() + + metadata = AuthorizationServerMetadata({"op_policy_uri": "invalid"}) + with pytest.raises(ValueError, match="MUST be a URL"): metadata.validate_op_policy_uri() - def test_validate_op_tos_uri(self): - metadata = AuthorizationServerMetadata() - metadata.validate_op_tos_uri() + metadata = AuthorizationServerMetadata({"op_policy_uri": "https://provider.test/"}) + metadata.validate_op_policy_uri() + - metadata = AuthorizationServerMetadata({ - 'op_tos_uri': 'invalid' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_op_tos_uri() - self.assertIn('MUST be a URL', str(cm.exception)) +def test_validate_op_tos_uri(): + metadata = AuthorizationServerMetadata() + metadata.validate_op_tos_uri() - metadata = AuthorizationServerMetadata({ - 'op_tos_uri': 'https://authlib.org/' - }) + metadata = AuthorizationServerMetadata({"op_tos_uri": "invalid"}) + with pytest.raises(ValueError, match="MUST be a URL"): metadata.validate_op_tos_uri() - def test_validate_revocation_endpoint(self): - metadata = AuthorizationServerMetadata() - metadata.validate_revocation_endpoint() + metadata = AuthorizationServerMetadata({"op_tos_uri": "https://provider.test/"}) + metadata.validate_op_tos_uri() + - # https - metadata = AuthorizationServerMetadata({ - 'revocation_endpoint': 'http://authlib.org/' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_revocation_endpoint() - self.assertIn('https', str(cm.exception)) - - # valid - metadata = AuthorizationServerMetadata({ - 'revocation_endpoint': 'https://authlib.org/' - }) +def test_validate_revocation_endpoint(): + metadata = AuthorizationServerMetadata() + metadata.validate_revocation_endpoint() + + # https + metadata = AuthorizationServerMetadata( + {"revocation_endpoint": "http://provider.test/"} + ) + with pytest.raises(ValueError, match="https"): metadata.validate_revocation_endpoint() - def test_validate_revocation_endpoint_auth_methods_supported(self): - metadata = AuthorizationServerMetadata() - metadata.validate_revocation_endpoint_auth_methods_supported() + # valid + metadata = AuthorizationServerMetadata( + {"revocation_endpoint": "https://provider.test/"} + ) + metadata.validate_revocation_endpoint() + - # not array - metadata = AuthorizationServerMetadata({ - 'revocation_endpoint_auth_methods_supported': 'client_secret_basic' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_revocation_endpoint_auth_methods_supported() - self.assertIn('JSON array', str(cm.exception)) - - # valid - metadata = AuthorizationServerMetadata({ - 'revocation_endpoint_auth_methods_supported': ['client_secret_basic'] - }) +def test_validate_revocation_endpoint_auth_methods_supported(): + metadata = AuthorizationServerMetadata() + metadata.validate_revocation_endpoint_auth_methods_supported() + + # not array + metadata = AuthorizationServerMetadata( + {"revocation_endpoint_auth_methods_supported": "client_secret_basic"} + ) + with pytest.raises(ValueError, match="JSON array"): metadata.validate_revocation_endpoint_auth_methods_supported() - def test_validate_revocation_endpoint_auth_signing_alg_values_supported(self): - metadata = AuthorizationServerMetadata() + # valid + metadata = AuthorizationServerMetadata( + {"revocation_endpoint_auth_methods_supported": ["client_secret_basic"]} + ) + metadata.validate_revocation_endpoint_auth_methods_supported() + + +def test_validate_revocation_endpoint_auth_signing_alg_values_supported(): + metadata = AuthorizationServerMetadata() + metadata.validate_revocation_endpoint_auth_signing_alg_values_supported() + + metadata = AuthorizationServerMetadata( + {"revocation_endpoint_auth_methods_supported": ["client_secret_jwt"]} + ) + with pytest.raises(ValueError, match="required"): metadata.validate_revocation_endpoint_auth_signing_alg_values_supported() - metadata = AuthorizationServerMetadata({ - 'revocation_endpoint_auth_methods_supported': ['client_secret_jwt'] - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_revocation_endpoint_auth_signing_alg_values_supported() - self.assertIn('required', str(cm.exception)) - - metadata = AuthorizationServerMetadata({ - 'revocation_endpoint_auth_signing_alg_values_supported': 'RS256' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_revocation_endpoint_auth_signing_alg_values_supported() - self.assertIn('JSON array', str(cm.exception)) - - metadata = AuthorizationServerMetadata({ - 'revocation_endpoint_auth_methods_supported': ['client_secret_jwt'], - 'revocation_endpoint_auth_signing_alg_values_supported': ['RS256', 'none'] - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_revocation_endpoint_auth_signing_alg_values_supported() - self.assertIn('none', str(cm.exception)) - - def test_validate_introspection_endpoint(self): - metadata = AuthorizationServerMetadata() - metadata.validate_introspection_endpoint() + metadata = AuthorizationServerMetadata( + {"revocation_endpoint_auth_signing_alg_values_supported": "RS256"} + ) + with pytest.raises(ValueError, match="JSON array"): + metadata.validate_revocation_endpoint_auth_signing_alg_values_supported() + + metadata = AuthorizationServerMetadata( + { + "revocation_endpoint_auth_methods_supported": ["client_secret_jwt"], + "revocation_endpoint_auth_signing_alg_values_supported": [ + "RS256", + "none", + ], + } + ) + with pytest.raises(ValueError, match="none"): + metadata.validate_revocation_endpoint_auth_signing_alg_values_supported() - # https - metadata = AuthorizationServerMetadata({ - 'introspection_endpoint': 'http://authlib.org/' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_introspection_endpoint() - self.assertIn('https', str(cm.exception)) - - # valid - metadata = AuthorizationServerMetadata({ - 'introspection_endpoint': 'https://authlib.org/' - }) + +def test_validate_introspection_endpoint(): + metadata = AuthorizationServerMetadata() + metadata.validate_introspection_endpoint() + + # https + metadata = AuthorizationServerMetadata( + {"introspection_endpoint": "http://provider.test/"} + ) + with pytest.raises(ValueError, match="https"): metadata.validate_introspection_endpoint() - def test_validate_introspection_endpoint_auth_methods_supported(self): - metadata = AuthorizationServerMetadata() - metadata.validate_introspection_endpoint_auth_methods_supported() + # valid + metadata = AuthorizationServerMetadata( + {"introspection_endpoint": "https://provider.test/"} + ) + metadata.validate_introspection_endpoint() + + +def test_validate_introspection_endpoint_auth_methods_supported(): + metadata = AuthorizationServerMetadata() + metadata.validate_introspection_endpoint_auth_methods_supported() - # not array - metadata = AuthorizationServerMetadata({ - 'introspection_endpoint_auth_methods_supported': 'client_secret_basic' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_introspection_endpoint_auth_methods_supported() - self.assertIn('JSON array', str(cm.exception)) - - # valid - metadata = AuthorizationServerMetadata({ - 'introspection_endpoint_auth_methods_supported': ['client_secret_basic'] - }) + # not array + metadata = AuthorizationServerMetadata( + {"introspection_endpoint_auth_methods_supported": "client_secret_basic"} + ) + with pytest.raises(ValueError, match="JSON array"): metadata.validate_introspection_endpoint_auth_methods_supported() - def test_validate_introspection_endpoint_auth_signing_alg_values_supported(self): - metadata = AuthorizationServerMetadata() + # valid + metadata = AuthorizationServerMetadata( + {"introspection_endpoint_auth_methods_supported": ["client_secret_basic"]} + ) + metadata.validate_introspection_endpoint_auth_methods_supported() + + +def test_validate_introspection_endpoint_auth_signing_alg_values_supported(): + metadata = AuthorizationServerMetadata() + metadata.validate_introspection_endpoint_auth_signing_alg_values_supported() + + metadata = AuthorizationServerMetadata( + {"introspection_endpoint_auth_methods_supported": ["client_secret_jwt"]} + ) + with pytest.raises(ValueError, match="required"): metadata.validate_introspection_endpoint_auth_signing_alg_values_supported() - metadata = AuthorizationServerMetadata({ - 'introspection_endpoint_auth_methods_supported': ['client_secret_jwt'] - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_introspection_endpoint_auth_signing_alg_values_supported() - self.assertIn('required', str(cm.exception)) - - metadata = AuthorizationServerMetadata({ - 'introspection_endpoint_auth_signing_alg_values_supported': 'RS256' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_introspection_endpoint_auth_signing_alg_values_supported() - self.assertIn('JSON array', str(cm.exception)) - - metadata = AuthorizationServerMetadata({ - 'introspection_endpoint_auth_methods_supported': ['client_secret_jwt'], - 'introspection_endpoint_auth_signing_alg_values_supported': ['RS256', 'none'] - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_introspection_endpoint_auth_signing_alg_values_supported() - self.assertIn('none', str(cm.exception)) - - def test_validate_code_challenge_methods_supported(self): - metadata = AuthorizationServerMetadata() - metadata.validate_code_challenge_methods_supported() + metadata = AuthorizationServerMetadata( + {"introspection_endpoint_auth_signing_alg_values_supported": "RS256"} + ) + with pytest.raises(ValueError, match="JSON array"): + metadata.validate_introspection_endpoint_auth_signing_alg_values_supported() - # not array - metadata = AuthorizationServerMetadata({ - 'code_challenge_methods_supported': 'S256' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_code_challenge_methods_supported() - self.assertIn('JSON array', str(cm.exception)) - - # valid - metadata = AuthorizationServerMetadata({ - 'code_challenge_methods_supported': ['S256'] - }) + metadata = AuthorizationServerMetadata( + { + "introspection_endpoint_auth_methods_supported": ["client_secret_jwt"], + "introspection_endpoint_auth_signing_alg_values_supported": [ + "RS256", + "none", + ], + } + ) + with pytest.raises(ValueError, match="none"): + metadata.validate_introspection_endpoint_auth_signing_alg_values_supported() + + +def test_validate_code_challenge_methods_supported(): + metadata = AuthorizationServerMetadata() + metadata.validate_code_challenge_methods_supported() + + # not array + metadata = AuthorizationServerMetadata({"code_challenge_methods_supported": "S256"}) + with pytest.raises(ValueError, match="JSON array"): metadata.validate_code_challenge_methods_supported() + + # valid + metadata = AuthorizationServerMetadata( + {"code_challenge_methods_supported": ["S256"]} + ) + metadata.validate_code_challenge_methods_supported() + + +def test_validate_with_metadata_classes(): + """Test that validate() can compose metadata extension classes.""" + + base_metadata = { + "issuer": "https://provider.test", + "authorization_endpoint": "https://provider.test/auth", + "token_endpoint": "https://provider.test/token", + "response_types_supported": ["code"], + } + + metadata = AuthorizationServerMetadata( + {**base_metadata, "require_signed_request_object": True} + ) + metadata.validate(metadata_classes=[rfc9101.AuthorizationServerMetadata]) + + metadata = AuthorizationServerMetadata( + {**base_metadata, "require_signed_request_object": "invalid"} + ) + with pytest.raises(ValueError, match="boolean"): + metadata.validate(metadata_classes=[rfc9101.AuthorizationServerMetadata]) diff --git a/tests/core/test_oauth2/test_rfc9207.py b/tests/core/test_oauth2/test_rfc9207.py new file mode 100644 index 000000000..6364481e7 --- /dev/null +++ b/tests/core/test_oauth2/test_rfc9207.py @@ -0,0 +1,45 @@ +import pytest + +from authlib.oauth2 import rfc8414 +from authlib.oauth2 import rfc9207 + + +def test_validate_authorization_response_iss_parameter_supported(): + metadata = rfc9207.AuthorizationServerMetadata() + metadata.validate_authorization_response_iss_parameter_supported() + + metadata = rfc9207.AuthorizationServerMetadata( + {"authorization_response_iss_parameter_supported": True} + ) + metadata.validate_authorization_response_iss_parameter_supported() + + metadata = rfc9207.AuthorizationServerMetadata( + {"authorization_response_iss_parameter_supported": False} + ) + metadata.validate_authorization_response_iss_parameter_supported() + + metadata = rfc9207.AuthorizationServerMetadata( + {"authorization_response_iss_parameter_supported": "invalid"} + ) + with pytest.raises(ValueError, match="boolean"): + metadata.validate_authorization_response_iss_parameter_supported() + + +def test_metadata_classes_composition(): + base_metadata = { + "issuer": "https://provider.test", + "authorization_endpoint": "https://provider.test/auth", + "token_endpoint": "https://provider.test/token", + "response_types_supported": ["code"], + } + + metadata = rfc8414.AuthorizationServerMetadata( + {**base_metadata, "authorization_response_iss_parameter_supported": True} + ) + metadata.validate(metadata_classes=[rfc9207.AuthorizationServerMetadata]) + + metadata = rfc8414.AuthorizationServerMetadata( + {**base_metadata, "authorization_response_iss_parameter_supported": "invalid"} + ) + with pytest.raises(ValueError, match="boolean"): + metadata.validate(metadata_classes=[rfc9207.AuthorizationServerMetadata]) diff --git a/tests/core/test_oidc/test_core.py b/tests/core/test_oidc/test_core.py index 92e76bc36..9e29a8eb2 100644 --- a/tests/core/test_oidc/test_core.py +++ b/tests/core/test_oidc/test_core.py @@ -1,145 +1,174 @@ -import unittest -from authlib.jose.errors import MissingClaimError, InvalidClaimError -from authlib.oidc.core import CodeIDToken, ImplicitIDToken, HybridIDToken -from authlib.oidc.core import UserInfo, get_claim_cls_by_response_type - - -class IDTokenTest(unittest.TestCase): - def test_essential_claims(self): - claims = CodeIDToken({}, {}) - self.assertRaises(MissingClaimError, claims.validate) - claims = CodeIDToken({ - 'iss': '1', - 'sub': '1', - 'aud': '1', - 'exp': 10000, - 'iat': 100 - }, {}) +import pytest +from joserfc.errors import InvalidClaimError +from joserfc.errors import MissingClaimError + +from authlib.oidc.core import CodeIDToken +from authlib.oidc.core import HybridIDToken +from authlib.oidc.core import ImplicitIDToken +from authlib.oidc.core import UserInfo +from authlib.oidc.core import get_claim_cls_by_response_type + + +def test_essential_claims(): + claims = CodeIDToken({}, {}) + with pytest.raises(MissingClaimError): + claims.validate() + claims = CodeIDToken( + {"iss": "1", "sub": "1", "aud": "1", "exp": 10000, "iat": 100}, {} + ) + claims.validate(1000) + + +def test_validate_auth_time(): + claims = CodeIDToken( + {"iss": "1", "sub": "1", "aud": "1", "exp": 10000, "iat": 100}, {} + ) + claims.params = {"max_age": 100} + with pytest.raises(MissingClaimError): claims.validate(1000) - def test_validate_auth_time(self): - claims = CodeIDToken({ - 'iss': '1', - 'sub': '1', - 'aud': '1', - 'exp': 10000, - 'iat': 100 - }, {}) - claims.params = {'max_age': 100} - self.assertRaises(MissingClaimError, claims.validate, 1000) - - claims['auth_time'] = 'foo' - self.assertRaises(InvalidClaimError, claims.validate, 1000) - - def test_validate_nonce(self): - claims = CodeIDToken({ - 'iss': '1', - 'sub': '1', - 'aud': '1', - 'exp': 10000, - 'iat': 100 - }, {}) - claims.params = {'nonce': 'foo'} - self.assertRaises(MissingClaimError, claims.validate, 1000) - claims['nonce'] = 'bar' - self.assertRaises(InvalidClaimError, claims.validate, 1000) - claims['nonce'] = 'foo' + claims["auth_time"] = "foo" + with pytest.raises(InvalidClaimError): claims.validate(1000) - def test_validate_amr(self): - claims = CodeIDToken({ - 'iss': '1', - 'sub': '1', - 'aud': '1', - 'exp': 10000, - 'iat': 100, - 'amr': 'invalid' - }, {}) - self.assertRaises(InvalidClaimError, claims.validate, 1000) - - def test_validate_azp(self): - claims = CodeIDToken({ - 'iss': '1', - 'sub': '1', - 'aud': '1', - 'exp': 10000, - 'iat': 100, - }, {}) - claims.params = {'client_id': '2'} - self.assertRaises(MissingClaimError, claims.validate, 1000) - - claims['azp'] = '1' - self.assertRaises(InvalidClaimError, claims.validate, 1000) - - claims['azp'] = '2' + +def test_validate_nonce(): + claims = CodeIDToken( + {"iss": "1", "sub": "1", "aud": "1", "exp": 10000, "iat": 100}, {} + ) + claims.params = {"nonce": "foo"} + with pytest.raises(MissingClaimError): + claims.validate(1000) + claims["nonce"] = "bar" + with pytest.raises(InvalidClaimError): + claims.validate(1000) + claims["nonce"] = "foo" + claims.validate(1000) + + +def test_validate_amr(): + claims = CodeIDToken( + { + "iss": "1", + "sub": "1", + "aud": "1", + "exp": 10000, + "iat": 100, + "amr": "invalid", + }, + {}, + ) + with pytest.raises(InvalidClaimError): + claims.validate(1000) + + +def test_validate_azp(): + claims = CodeIDToken( + { + "iss": "1", + "sub": "1", + "aud": "1", + "exp": 10000, + "iat": 100, + }, + {}, + ) + claims.params = {"client_id": "2"} + with pytest.raises(MissingClaimError): + claims.validate(1000) + + claims["azp"] = "1" + with pytest.raises(InvalidClaimError): + claims.validate(1000) + + claims["azp"] = "2" + claims.validate(1000) + + +def test_validate_at_hash(): + claims = CodeIDToken( + { + "iss": "1", + "sub": "1", + "aud": "1", + "exp": 10000, + "iat": 100, + "at_hash": "a", + }, + {}, + ) + claims.params = {"access_token": "a"} + + # invalid alg will raise too + claims.header = {"alg": "HS222"} + with pytest.raises(InvalidClaimError): claims.validate(1000) - def test_validate_at_hash(self): - claims = CodeIDToken({ - 'iss': '1', - 'sub': '1', - 'aud': '1', - 'exp': 10000, - 'iat': 100, - 'at_hash': 'a' - }, {}) - claims.params = {'access_token': 'a'} - - # invalid alg won't raise - claims.header = {'alg': 'HS222'} + claims.header = {"alg": "HS256"} + with pytest.raises(InvalidClaimError): claims.validate(1000) - claims.header = {'alg': 'HS256'} - self.assertRaises(InvalidClaimError, claims.validate, 1000) - - def test_implicit_id_token(self): - claims = ImplicitIDToken({ - 'iss': '1', - 'sub': '1', - 'aud': '1', - 'exp': 10000, - 'iat': 100, - 'nonce': 'a' - }, {}) - claims.params = {'access_token': 'a'} - self.assertRaises(MissingClaimError, claims.validate, 1000) - - def test_hybrid_id_token(self): - claims = HybridIDToken({ - 'iss': '1', - 'sub': '1', - 'aud': '1', - 'exp': 10000, - 'iat': 100, - 'nonce': 'a' - }, {}) + +def test_implicit_id_token(): + claims = ImplicitIDToken( + { + "iss": "1", + "sub": "1", + "aud": "1", + "exp": 10000, + "iat": 100, + "nonce": "a", + }, + {}, + ) + claims.params = {"access_token": "a"} + with pytest.raises(MissingClaimError): claims.validate(1000) - claims.params = {'code': 'a'} - self.assertRaises(MissingClaimError, claims.validate, 1000) - # invalid alg won't raise - claims.header = {'alg': 'HS222'} - claims['c_hash'] = 'a' +def test_hybrid_id_token(): + claims = HybridIDToken( + { + "iss": "1", + "sub": "1", + "aud": "1", + "exp": 10000, + "iat": 100, + "nonce": "a", + }, + {}, + ) + claims.validate(1000) + + claims.params = {"code": "a"} + with pytest.raises(MissingClaimError): claims.validate(1000) - claims.header = {'alg': 'HS256'} - self.assertRaises(InvalidClaimError, claims.validate, 1000) - - def test_get_claim_cls_by_response_type(self): - cls = get_claim_cls_by_response_type('id_token') - self.assertEqual(cls, ImplicitIDToken) - cls = get_claim_cls_by_response_type('code') - self.assertEqual(cls, CodeIDToken) - cls = get_claim_cls_by_response_type('code id_token') - self.assertEqual(cls, HybridIDToken) - cls = get_claim_cls_by_response_type('none') - self.assertIsNone(cls) - - -class UserInfoTest(unittest.TestCase): - def test_getattribute(self): - user = UserInfo({'sub': '1'}) - self.assertEqual(user.sub, '1') - self.assertIsNone(user.email, None) - self.assertRaises(AttributeError, lambda: user.invalid) + # invalid alg will raise too + claims.header = {"alg": "HS222"} + claims["c_hash"] = "a" + with pytest.raises(InvalidClaimError): + claims.validate(1000) + + claims.header = {"alg": "HS256"} + with pytest.raises(InvalidClaimError): + claims.validate(1000) + + +def test_get_claim_cls_by_response_type(): + cls = get_claim_cls_by_response_type("id_token") + assert cls == ImplicitIDToken + cls = get_claim_cls_by_response_type("code") + assert cls == CodeIDToken + cls = get_claim_cls_by_response_type("code id_token") + assert cls == HybridIDToken + cls = get_claim_cls_by_response_type("none") + assert cls is None + + +def test_userinfo_getattribute(): + user = UserInfo({"sub": "1"}) + assert user.sub == "1" + assert user.email is None + with pytest.raises(AttributeError): + user.invalid # noqa: B018 diff --git a/tests/core/test_oidc/test_discovery.py b/tests/core/test_oidc/test_discovery.py index 043ab11e2..8fd9ce8ab 100644 --- a/tests/core/test_oidc/test_discovery.py +++ b/tests/core/test_oidc/test_discovery.py @@ -1,230 +1,180 @@ -import unittest -from authlib.oidc.discovery import get_well_known_url, OpenIDProviderMetadata - -WELL_KNOWN_URL = '/.well-known/openid-configuration' - - -class WellKnownTest(unittest.TestCase): - def test_no_suffix_issuer(self): - self.assertEqual( - get_well_known_url('https://authlib.org'), - WELL_KNOWN_URL - ) - self.assertEqual( - get_well_known_url('https://authlib.org/'), - WELL_KNOWN_URL - ) - - def test_with_suffix_issuer(self): - self.assertEqual( - get_well_known_url('https://authlib.org/issuer1'), - '/issuer1' + WELL_KNOWN_URL - ) - self.assertEqual( - get_well_known_url('https://authlib.org/a/b/c'), - '/a/b/c' + WELL_KNOWN_URL - ) - - def test_with_external(self): - self.assertEqual( - get_well_known_url('https://authlib.org', external=True), - 'https://authlib.org' + WELL_KNOWN_URL - ) - - -class OpenIDProviderMetadataTest(unittest.TestCase): - def test_validate_jwks_uri(self): - # required - metadata = OpenIDProviderMetadata() - with self.assertRaises(ValueError) as cm: - metadata.validate_jwks_uri() - self.assertEqual('"jwks_uri" is required', str(cm.exception)) - - metadata = OpenIDProviderMetadata({ - 'jwks_uri': 'http://authlib.org/jwks.json' - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_jwks_uri() - self.assertIn('https', str(cm.exception)) - - metadata = OpenIDProviderMetadata({ - 'jwks_uri': 'https://authlib.org/jwks.json' - }) +import pytest + +from authlib.oidc.discovery import OpenIDProviderMetadata +from authlib.oidc.discovery import get_well_known_url + +WELL_KNOWN_URL = "/.well-known/openid-configuration" + + +def test_well_known_no_suffix_issuer(): + assert get_well_known_url("https://provider.test") == WELL_KNOWN_URL + assert get_well_known_url("https://provider.test/") == WELL_KNOWN_URL + + +def test_well_known_with_suffix_issuer(): + assert ( + get_well_known_url("https://provider.test/issuer1") + == "/issuer1" + WELL_KNOWN_URL + ) + assert ( + get_well_known_url("https://provider.test/a/b/c") == "/a/b/c" + WELL_KNOWN_URL + ) + + +def test_well_known_with_external(): + assert ( + get_well_known_url("https://provider.test", external=True) + == "https://provider.test" + WELL_KNOWN_URL + ) + + +def test_validate_jwks_uri(): + # required + metadata = OpenIDProviderMetadata() + with pytest.raises(ValueError, match='"jwks_uri" is required'): metadata.validate_jwks_uri() - def test_validate_acr_values_supported(self): - self._call_validate_array( - 'acr_values_supported', - ['urn:mace:incommon:iap:silver'] - ) - - def test_validate_subject_types_supported(self): - self._call_validate_array( - 'subject_types_supported', - ['pairwise', 'public'], - required=True - ) - self._call_contains_invalid_value( - 'subject_types_supported', - ['invalid'] - ) - - def test_validate_id_token_signing_alg_values_supported(self): - self._call_validate_array( - 'id_token_signing_alg_values_supported', - ['RS256'], required=True, - ) - metadata = OpenIDProviderMetadata({ - 'id_token_signing_alg_values_supported': ['none'] - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_id_token_signing_alg_values_supported() - self.assertIn('RS256', str(cm.exception)) - - def test_validate_id_token_encryption_alg_values_supported(self): - self._call_validate_array( - 'id_token_encryption_alg_values_supported', - ['A128KW'] - ) - - def test_validate_id_token_encryption_enc_values_supported(self): - self._call_validate_array( - 'id_token_encryption_enc_values_supported', - ['A128GCM'] - ) - - def test_validate_userinfo_signing_alg_values_supported(self): - self._call_validate_array( - 'userinfo_signing_alg_values_supported', - ['RS256'] - ) - - def test_validate_userinfo_encryption_alg_values_supported(self): - self._call_validate_array( - 'userinfo_encryption_alg_values_supported', - ['A128KW'] - ) - - def test_validate_userinfo_encryption_enc_values_supported(self): - self._call_validate_array( - 'userinfo_encryption_enc_values_supported', - ['A128GCM'] - ) - - def test_validate_request_object_signing_alg_values_supported(self): - self._call_validate_array( - 'request_object_signing_alg_values_supported', - ['none', 'RS256'] - ) - metadata = OpenIDProviderMetadata({ - 'request_object_signing_alg_values_supported': ['RS512'] - }) - with self.assertRaises(ValueError) as cm: - metadata.validate_request_object_signing_alg_values_supported() - self.assertIn('SHOULD support none and RS256', str(cm.exception)) - - def test_validate_request_object_encryption_alg_values_supported(self): - self._call_validate_array( - 'request_object_encryption_alg_values_supported', - ['A128KW'] - ) - - def test_validate_request_object_encryption_enc_values_supported(self): - self._call_validate_array( - 'request_object_encryption_enc_values_supported', - ['A128GCM'] - ) - - def test_validate_display_values_supported(self): - self._call_validate_array( - 'display_values_supported', - ['page', 'touch'] - ) - self._call_contains_invalid_value( - 'display_values_supported', - ['invalid'] - ) - - def test_validate_claim_types_supported(self): - self._call_validate_array( - 'claim_types_supported', - ['normal'] - ) - self._call_contains_invalid_value( - 'claim_types_supported', - ['invalid'] - ) - metadata = OpenIDProviderMetadata() - self.assertEqual(metadata.claim_types_supported, ['normal']) - - def test_validate_claims_supported(self): - self._call_validate_array( - 'claims_supported', - ['sub'] - ) - - def test_validate_claims_locales_supported(self): - self._call_validate_array( - 'claims_locales_supported', - ['en-US'] - ) - - def test_validate_claims_parameter_supported(self): - self._call_validate_boolean('claims_parameter_supported') - - def test_validate_request_parameter_supported(self): - self._call_validate_boolean('request_parameter_supported') - - def test_validate_request_uri_parameter_supported(self): - self._call_validate_boolean('request_uri_parameter_supported', True) - - def test_validate_require_request_uri_registration(self): - self._call_validate_boolean('require_request_uri_registration') - - def _call_validate_boolean(self, key, default_value=False): - def _validate(metadata): - getattr(metadata, 'validate_' + key)() - - metadata = OpenIDProviderMetadata() - _validate(metadata) - self.assertEqual(getattr(metadata, key), default_value) + metadata = OpenIDProviderMetadata({"jwks_uri": "http://provider.test/jwks.json"}) + with pytest.raises(ValueError, match="https"): + metadata.validate_jwks_uri() - metadata = OpenIDProviderMetadata({key: 'str'}) - with self.assertRaises(ValueError) as cm: - _validate(metadata) - self.assertIn('MUST be boolean', str(cm.exception)) - metadata = OpenIDProviderMetadata({key: True}) + metadata = OpenIDProviderMetadata({"jwks_uri": "https://provider.test/jwks.json"}) + metadata.validate_jwks_uri() + + +def test_validate_acr_values_supported(): + _call_validate_array("acr_values_supported", ["urn:mace:incommon:iap:silver"]) + + +def test_validate_subject_types_supported(): + _call_validate_array( + "subject_types_supported", ["pairwise", "public"], required=True + ) + _call_contains_invalid_value("subject_types_supported", ["invalid"]) + + +def test_validate_id_token_signing_alg_values_supported(): + _call_validate_array( + "id_token_signing_alg_values_supported", + ["RS256"], + required=True, + ) + metadata = OpenIDProviderMetadata( + {"id_token_signing_alg_values_supported": ["none"]} + ) + with pytest.raises(ValueError, match="RS256"): + metadata.validate_id_token_signing_alg_values_supported() + + +def test_validate_id_token_encryption_alg_values_supported(): + _call_validate_array("id_token_encryption_alg_values_supported", ["A128KW"]) + + +def test_validate_id_token_encryption_enc_values_supported(): + _call_validate_array("id_token_encryption_enc_values_supported", ["A128GCM"]) + + +def test_validate_userinfo_signing_alg_values_supported(): + _call_validate_array("userinfo_signing_alg_values_supported", ["RS256"]) + + +def test_validate_userinfo_encryption_alg_values_supported(): + _call_validate_array("userinfo_encryption_alg_values_supported", ["A128KW"]) + + +def test_validate_userinfo_encryption_enc_values_supported(): + _call_validate_array("userinfo_encryption_enc_values_supported", ["A128GCM"]) + + +def test_validate_request_object_signing_alg_values_supported(): + _call_validate_array( + "request_object_signing_alg_values_supported", ["none", "RS256"] + ) + + +def test_validate_request_object_encryption_alg_values_supported(): + _call_validate_array("request_object_encryption_alg_values_supported", ["A128KW"]) + + +def test_validate_request_object_encryption_enc_values_supported(): + _call_validate_array("request_object_encryption_enc_values_supported", ["A128GCM"]) + + +def test_validate_display_values_supported(): + _call_validate_array("display_values_supported", ["page", "touch"]) + _call_contains_invalid_value("display_values_supported", ["invalid"]) + + +def test_validate_claim_types_supported(): + _call_validate_array("claim_types_supported", ["normal"]) + _call_contains_invalid_value("claim_types_supported", ["invalid"]) + metadata = OpenIDProviderMetadata() + assert metadata.claim_types_supported == ["normal"] + + +def test_validate_claims_supported(): + _call_validate_array("claims_supported", ["sub"]) + + +def test_validate_claims_locales_supported(): + _call_validate_array("claims_locales_supported", ["en-US"]) + + +def test_validate_claims_parameter_supported(): + _call_validate_boolean("claims_parameter_supported") + + +def test_validate_request_parameter_supported(): + _call_validate_boolean("request_parameter_supported") + + +def test_validate_request_uri_parameter_supported(): + _call_validate_boolean("request_uri_parameter_supported", True) + + +def test_validate_require_request_uri_registration(): + _call_validate_boolean("require_request_uri_registration") + + +def _call_validate_boolean(key, default_value=False): + def _validate(metadata): + getattr(metadata, "validate_" + key)() + + metadata = OpenIDProviderMetadata() + _validate(metadata) + assert getattr(metadata, key) == default_value + + metadata = OpenIDProviderMetadata({key: "str"}) + with pytest.raises(ValueError, match="MUST be boolean"): _validate(metadata) - def _call_validate_array(self, key, valid_value, required=False): - def _validate(metadata): - getattr(metadata, 'validate_' + key)() + metadata = OpenIDProviderMetadata({key: True}) + _validate(metadata) - metadata = OpenIDProviderMetadata() - if required: - with self.assertRaises(ValueError) as cm: - _validate(metadata) - self.assertEqual('"{}" is required'.format(key), str(cm.exception)) - else: - _validate(metadata) - # not array - metadata = OpenIDProviderMetadata({key: 'foo'}) - with self.assertRaises(ValueError) as cm: +def _call_validate_array(key, valid_value, required=False): + def _validate(metadata): + getattr(metadata, "validate_" + key)() + + metadata = OpenIDProviderMetadata() + if required: + with pytest.raises(ValueError, match=f'"{key}" is required'): _validate(metadata) - self.assertIn('JSON array', str(cm.exception)) - # valid - metadata = OpenIDProviderMetadata({key: valid_value}) + else: + _validate(metadata) + + # not array + metadata = OpenIDProviderMetadata({key: "foo"}) + with pytest.raises(ValueError, match="JSON array"): _validate(metadata) - def _call_contains_invalid_value(self, key, invalid_value): - metadata = OpenIDProviderMetadata({key: invalid_value}) - with self.assertRaises(ValueError) as cm: - getattr(metadata, 'validate_' + key)() - self.assertEqual( - '"{}" contains invalid values'.format(key), - str(cm.exception) - ) + # valid + metadata = OpenIDProviderMetadata({key: valid_value}) + _validate(metadata) +def _call_contains_invalid_value(key, invalid_value): + metadata = OpenIDProviderMetadata({key: invalid_value}) + with pytest.raises(ValueError, match=f'"{key}" contains invalid values'): + getattr(metadata, "validate_" + key)() diff --git a/tests/core/test_oidc/test_registration.py b/tests/core/test_oidc/test_registration.py new file mode 100644 index 000000000..8916f967a --- /dev/null +++ b/tests/core/test_oidc/test_registration.py @@ -0,0 +1,51 @@ +import pytest +from joserfc.errors import InvalidClaimError + +from authlib.oidc.registration import ClientMetadataClaims + + +def test_request_uris(): + claims = ClientMetadataClaims( + {"request_uris": ["https://client.test/request_uris"]}, {} + ) + claims.validate() + + claims = ClientMetadataClaims({"request_uris": ["invalid"]}, {}) + with pytest.raises(InvalidClaimError): + claims.validate() + + +def test_initiate_login_uri(): + claims = ClientMetadataClaims( + {"initiate_login_uri": "https://client.test/initiate_login_uri"}, {} + ) + claims.validate() + + claims = ClientMetadataClaims({"initiate_login_uri": "invalid"}, {}) + with pytest.raises(InvalidClaimError): + claims.validate() + + +def test_token_endpoint_auth_signing_alg(): + claims = ClientMetadataClaims({"token_endpoint_auth_signing_alg": "RSA256"}, {}) + claims.validate() + + # The value none MUST NOT be used. + claims = ClientMetadataClaims({"token_endpoint_auth_signing_alg": "none"}, {}) + with pytest.raises(InvalidClaimError): + claims.validate() + + +def test_id_token_signed_response_alg(): + claims = ClientMetadataClaims({"id_token_signed_response_alg": "RSA256"}, {}) + claims.validate() + + +def test_default_max_age(): + claims = ClientMetadataClaims({"default_max_age": 1234}, {}) + claims.validate() + + # The value none MUST NOT be used. + claims = ClientMetadataClaims({"default_max_age": "invalid"}, {}) + with pytest.raises(InvalidClaimError): + claims.validate() diff --git a/tests/core/test_oidc/test_rpinitiated.py b/tests/core/test_oidc/test_rpinitiated.py new file mode 100644 index 000000000..5ad9d235b --- /dev/null +++ b/tests/core/test_oidc/test_rpinitiated.py @@ -0,0 +1,91 @@ +import pytest +from joserfc.errors import InvalidClaimError + +from authlib.oidc import discovery +from authlib.oidc import rpinitiated +from authlib.oidc.rpinitiated import ClientMetadataClaims + + +@pytest.fixture +def valid_oidc_metadata(): + return { + "issuer": "https://provider.test", + "authorization_endpoint": "https://provider.test/authorize", + "token_endpoint": "https://provider.test/token", + "jwks_uri": "https://provider.test/jwks.json", + "response_types_supported": ["code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256"], + } + + +def test_validate_end_session_endpoint(valid_oidc_metadata): + valid_oidc_metadata["end_session_endpoint"] = "https://provider.test/logout" + metadata = discovery.OpenIDProviderMetadata(valid_oidc_metadata) + metadata.validate(metadata_classes=[rpinitiated.OpenIDProviderMetadata]) + + +def test_validate_end_session_endpoint_missing(valid_oidc_metadata): + """end_session_endpoint is optional.""" + metadata = discovery.OpenIDProviderMetadata(valid_oidc_metadata) + metadata.validate(metadata_classes=[rpinitiated.OpenIDProviderMetadata]) + + +def test_validate_end_session_endpoint_insecure(valid_oidc_metadata): + valid_oidc_metadata["end_session_endpoint"] = "http://provider.test/logout" + metadata = discovery.OpenIDProviderMetadata(valid_oidc_metadata) + with pytest.raises(ValueError, match="https"): + metadata.validate(metadata_classes=[rpinitiated.OpenIDProviderMetadata]) + + +def test_post_logout_redirect_uris(): + claims = ClientMetadataClaims( + {"post_logout_redirect_uris": ["https://client.test/logout"]}, {} + ) + claims.validate() + + claims = ClientMetadataClaims( + { + "post_logout_redirect_uris": [ + "https://client.test/logout", + "https://client.test/logged-out", + ] + }, + {}, + ) + claims.validate() + + claims = ClientMetadataClaims({"post_logout_redirect_uris": ["invalid"]}, {}) + with pytest.raises(InvalidClaimError): + claims.validate() + + +def test_post_logout_redirect_uris_empty(): + """Empty list should be valid.""" + claims = ClientMetadataClaims({"post_logout_redirect_uris": []}, {}) + claims.validate() + + +def test_post_logout_redirect_uris_insecure_public_client(): + """HTTP URIs should be rejected for public clients.""" + claims = ClientMetadataClaims( + { + "post_logout_redirect_uris": ["http://client.test/logout"], + "token_endpoint_auth_method": "none", + }, + {}, + ) + with pytest.raises(ValueError, match="public clients"): + claims.validate() + + +def test_post_logout_redirect_uris_insecure_confidential_client(): + """HTTP URIs should be accepted for confidential clients.""" + claims = ClientMetadataClaims( + { + "post_logout_redirect_uris": ["http://client.test/logout"], + "token_endpoint_auth_method": "client_secret_basic", + }, + {}, + ) + claims.validate() diff --git a/tests/core/test_oidc/test_utils.py b/tests/core/test_oidc/test_utils.py new file mode 100644 index 000000000..b83ae9f22 --- /dev/null +++ b/tests/core/test_oidc/test_utils.py @@ -0,0 +1,47 @@ +import pytest + +from authlib.jose.rfc7518.ec_key import ECKey +from authlib.jose.rfc7518.oct_key import OctKey +from authlib.jose.rfc7518.rsa_key import RSAKey +from authlib.jose.rfc8037.okp_key import OKPKey +from authlib.oidc.core import UserInfo +from authlib.oidc.core.grants.util import generate_id_token + +hmac_key = OctKey.generate_key(256) +rsa_key = RSAKey.generate_key(2048, is_private=True) +ec_key = ECKey.generate_key("P-256", is_private=True) +okp_key = OKPKey.generate_key("Ed25519", is_private=True) +ec_secp256k1_key = ECKey.generate_key("secp256k1", is_private=True) + + +@pytest.mark.parametrize( + "alg,key", + [ + ("none", None), + ("HS256", hmac_key), + ("HS384", hmac_key), + ("HS512", hmac_key), + ("RS256", rsa_key), + ("RS384", rsa_key), + ("RS512", rsa_key), + ("ES256", ec_key), + ("PS256", rsa_key), + ("PS384", rsa_key), + ("PS512", rsa_key), + ("EdDSA", okp_key), + ("ES256K", ec_secp256k1_key), + ], +) +def test_generate_id_token(alg, key): + token = {"access_token": "test_token"} + user_info = UserInfo({"sub": "123"}) + + result = generate_id_token( + token=token, + user_info=user_info, + key=key, + iss="https://provider.test", + aud="client_id", + alg=alg, + ) + assert result is not None diff --git a/tests/core/test_requests_client/test_assertion_session.py b/tests/core/test_requests_client/test_assertion_session.py deleted file mode 100644 index 7b89b7cda..000000000 --- a/tests/core/test_requests_client/test_assertion_session.py +++ /dev/null @@ -1,66 +0,0 @@ -import mock -import time -from unittest import TestCase -from authlib.integrations.requests_client import AssertionSession - - -class AssertionSessionTest(TestCase): - - def setUp(self): - self.token = { - 'token_type': 'Bearer', - 'access_token': 'a', - 'refresh_token': 'b', - 'expires_in': '3600', - 'expires_at': int(time.time()) + 3600, - } - - def test_refresh_token(self): - def verifier(r, **kwargs): - resp = mock.MagicMock() - if r.url == 'https://i.b/token': - self.assertIn('assertion=', r.body) - resp.json = lambda: self.token - return resp - - sess = AssertionSession( - 'https://i.b/token', - grant_type=AssertionSession.JWT_BEARER_GRANT_TYPE, - issuer='foo', - subject='foo', - audience='foo', - alg='HS256', - key='secret', - ) - sess.send = verifier - sess.get('https://i.b') - - # trigger more case - now = int(time.time()) - sess = AssertionSession( - 'https://i.b/token', - issuer='foo', - subject=None, - audience='foo', - issued_at=now, - expires_at=now + 3600, - header={'alg': 'HS256'}, - key='secret', - scope='email', - claims={'test_mode': 'true'} - ) - sess.send = verifier - sess.get('https://i.b') - # trigger for branch test case - sess.get('https://i.b') - - def test_without_alg(self): - sess = AssertionSession( - 'https://i.b/token', - grant_type=AssertionSession.JWT_BEARER_GRANT_TYPE, - issuer='foo', - subject='foo', - audience='foo', - key='secret', - ) - self.assertRaises(ValueError, sess.get, 'https://i.b') diff --git a/tests/core/test_requests_client/test_oauth1_session.py b/tests/core/test_requests_client/test_oauth1_session.py deleted file mode 100644 index 2378d930d..000000000 --- a/tests/core/test_requests_client/test_oauth1_session.py +++ /dev/null @@ -1,275 +0,0 @@ -from __future__ import unicode_literals, print_function -import mock -import requests -from unittest import TestCase -from io import StringIO - -from authlib.oauth1 import ( - SIGNATURE_PLAINTEXT, - SIGNATURE_RSA_SHA1, - SIGNATURE_TYPE_BODY, - SIGNATURE_TYPE_QUERY, -) -from authlib.oauth1.rfc5849.util import escape -from authlib.common.encoding import to_unicode, unicode_type -from authlib.integrations.requests_client import OAuth1Session, OAuthError -from tests.client_base import mock_text_response -from tests.util import read_file_path - - -TEST_RSA_OAUTH_SIGNATURE = ( - "j8WF8PGjojT82aUDd2EL%2Bz7HCoHInFzWUpiEKMCy%2BJ2cYHWcBS7mXlmFDLgAKV0" - "P%2FyX4TrpXODYnJ6dRWdfghqwDpi%2FlQmB2jxCiGMdJoYxh3c5zDf26gEbGdP6D7O" - "Ssp5HUnzH6sNkmVjuE%2FxoJcHJdc23H6GhOs7VJ2LWNdbhKWP%2FMMlTrcoQDn8lz" - "%2Fb24WsJ6ae1txkUzpFOOlLM8aTdNtGL4OtsubOlRhNqnAFq93FyhXg0KjzUyIZzmMX" - "9Vx90jTks5QeBGYcLE0Op2iHb2u%2FO%2BEgdwFchgEwE5LgMUyHUI4F3Wglp28yHOAM" - "jPkI%2FkWMvpxtMrU3Z3KN31WQ%3D%3D" -) - - -class OAuth1SessionTest(TestCase): - - def test_no_client_id(self): - self.assertRaises(ValueError, lambda: OAuth1Session(None)) - - def test_signature_types(self): - def verify_signature(getter): - def fake_send(r, **kwargs): - signature = to_unicode(getter(r)) - self.assertIn('oauth_signature', signature) - resp = mock.MagicMock(spec=requests.Response) - resp.cookies = [] - return resp - return fake_send - - header = OAuth1Session('foo') - header.send = verify_signature(lambda r: r.headers['Authorization']) - header.post('https://i.b') - - query = OAuth1Session('foo', signature_type=SIGNATURE_TYPE_QUERY) - query.send = verify_signature(lambda r: r.url) - query.post('https://i.b') - - body = OAuth1Session('foo', signature_type=SIGNATURE_TYPE_BODY) - headers = {'Content-Type': 'application/x-www-form-urlencoded'} - body.send = verify_signature(lambda r: r.body) - body.post('https://i.b', headers=headers, data='') - - @mock.patch('authlib.oauth1.rfc5849.client_auth.generate_timestamp') - @mock.patch('authlib.oauth1.rfc5849.client_auth.generate_nonce') - def test_signature_methods(self, generate_nonce, generate_timestamp): - generate_nonce.return_value = 'abc' - generate_timestamp.return_value = '123' - - signature = ', '.join([ - 'OAuth oauth_nonce="abc"', - 'oauth_timestamp="123"', - 'oauth_version="1.0"', - 'oauth_signature_method="HMAC-SHA1"', - 'oauth_consumer_key="foo"', - 'oauth_signature="h2sRqLArjhlc5p3FTkuNogVHlKE%3D"' - ]) - auth = OAuth1Session('foo') - auth.send = self.verify_signature(signature) - auth.post('https://i.b') - - signature = ( - 'OAuth ' - 'oauth_nonce="abc", oauth_timestamp="123", oauth_version="1.0", ' - 'oauth_signature_method="PLAINTEXT", oauth_consumer_key="foo", ' - 'oauth_signature="%26"' - ) - auth = OAuth1Session('foo', signature_method=SIGNATURE_PLAINTEXT) - auth.send = self.verify_signature(signature) - auth.post('https://i.b') - - signature = ( - 'OAuth ' - 'oauth_nonce="abc", oauth_timestamp="123", oauth_version="1.0", ' - 'oauth_signature_method="RSA-SHA1", oauth_consumer_key="foo", ' - 'oauth_signature="{sig}"' - ).format(sig=TEST_RSA_OAUTH_SIGNATURE) - - rsa_key = read_file_path('rsa_private.pem') - auth = OAuth1Session( - 'foo', signature_method=SIGNATURE_RSA_SHA1, rsa_key=rsa_key) - auth.send = self.verify_signature(signature) - auth.post('https://i.b') - - @mock.patch('authlib.oauth1.rfc5849.client_auth.generate_timestamp') - @mock.patch('authlib.oauth1.rfc5849.client_auth.generate_nonce') - def test_binary_upload(self, generate_nonce, generate_timestamp): - generate_nonce.return_value = 'abc' - generate_timestamp.return_value = '123' - fake_xml = StringIO('hello world') - headers = {'Content-Type': 'application/xml'} - signature = ( - 'OAuth oauth_nonce="abc", oauth_timestamp="123", oauth_version="1.0", ' - 'oauth_signature_method="HMAC-SHA1", oauth_consumer_key="foo", ' - 'oauth_signature="h2sRqLArjhlc5p3FTkuNogVHlKE%3D"' - ) - auth = OAuth1Session('foo') - auth.send = self.verify_signature(signature) - auth.post('https://i.b', headers=headers, files=[('fake', fake_xml)]) - - @mock.patch('authlib.oauth1.rfc5849.client_auth.generate_timestamp') - @mock.patch('authlib.oauth1.rfc5849.client_auth.generate_nonce') - def test_nonascii(self, generate_nonce, generate_timestamp): - generate_nonce.return_value = 'abc' - generate_timestamp.return_value = '123' - signature = ( - 'OAuth oauth_nonce="abc", oauth_timestamp="123", oauth_version="1.0", ' - 'oauth_signature_method="HMAC-SHA1", oauth_consumer_key="foo", ' - 'oauth_signature="W0haoue5IZAZoaJiYCtfqwMf8x8%3D"' - ) - auth = OAuth1Session('foo') - auth.send = self.verify_signature(signature) - auth.post('https://i.b?cjk=%E5%95%A6%E5%95%A6') - - def test_redirect_uri(self): - sess = OAuth1Session('foo') - self.assertIsNone(sess.redirect_uri) - url = 'https://i.b' - sess.redirect_uri = url - self.assertEqual(sess.redirect_uri, url) - - def test_set_token(self): - sess = OAuth1Session('foo') - try: - sess.token = {} - except OAuthError as exc: - self.assertEqual(exc.error, 'missing_token') - - sess.token = {'oauth_token': 'a', 'oauth_token_secret': 'b'} - self.assertIsNone(sess.token['oauth_verifier']) - sess.token = {'oauth_token': 'a', 'oauth_verifier': 'c'} - self.assertEqual(sess.token['oauth_token_secret'], 'b') - self.assertEqual(sess.token['oauth_verifier'], 'c') - - sess.token = None - self.assertIsNone(sess.token['oauth_token']) - self.assertIsNone(sess.token['oauth_token_secret']) - self.assertIsNone(sess.token['oauth_verifier']) - - def test_create_authorization_url(self): - auth = OAuth1Session('foo') - url = 'https://example.comm/authorize' - token = 'asluif023sf' - auth_url = auth.create_authorization_url(url, request_token=token) - self.assertEqual(auth_url, url + '?oauth_token=' + token) - redirect_uri = 'https://c.b' - auth = OAuth1Session('foo', redirect_uri=redirect_uri) - auth_url = auth.create_authorization_url(url, request_token=token) - self.assertIn(escape(redirect_uri), auth_url) - - def test_parse_response_url(self): - url = 'https://i.b/callback?oauth_token=foo&oauth_verifier=bar' - auth = OAuth1Session('foo') - resp = auth.parse_authorization_response(url) - self.assertEqual(resp['oauth_token'], 'foo') - self.assertEqual(resp['oauth_verifier'], 'bar') - for k, v in resp.items(): - self.assertTrue(isinstance(k, unicode_type)) - self.assertTrue(isinstance(v, unicode_type)) - - def test_fetch_request_token(self): - auth = OAuth1Session('foo') - auth.send = mock_text_response('oauth_token=foo') - resp = auth.fetch_request_token('https://example.com/token') - self.assertEqual(resp['oauth_token'], 'foo') - for k, v in resp.items(): - self.assertTrue(isinstance(k, unicode_type)) - self.assertTrue(isinstance(v, unicode_type)) - - resp = auth.fetch_request_token('https://example.com/token', realm='A') - self.assertEqual(resp['oauth_token'], 'foo') - resp = auth.fetch_request_token('https://example.com/token', realm=['A', 'B']) - self.assertEqual(resp['oauth_token'], 'foo') - - def test_fetch_request_token_with_optional_arguments(self): - auth = OAuth1Session('foo') - auth.send = mock_text_response('oauth_token=foo') - resp = auth.fetch_request_token('https://example.com/token', - verify=False, stream=True) - self.assertEqual(resp['oauth_token'], 'foo') - for k, v in resp.items(): - self.assertTrue(isinstance(k, unicode_type)) - self.assertTrue(isinstance(v, unicode_type)) - - def test_fetch_access_token(self): - auth = OAuth1Session('foo', verifier='bar') - auth.send = mock_text_response('oauth_token=foo') - resp = auth.fetch_access_token('https://example.com/token') - self.assertEqual(resp['oauth_token'], 'foo') - for k, v in resp.items(): - self.assertTrue(isinstance(k, unicode_type)) - self.assertTrue(isinstance(v, unicode_type)) - - auth = OAuth1Session('foo', verifier='bar') - auth.send = mock_text_response('{"oauth_token":"foo"}') - resp = auth.fetch_access_token('https://example.com/token') - self.assertEqual(resp['oauth_token'], 'foo') - - auth = OAuth1Session('foo') - auth.send = mock_text_response('oauth_token=foo') - resp = auth.fetch_access_token( - 'https://example.com/token', verifier='bar') - self.assertEqual(resp['oauth_token'], 'foo') - - def test_fetch_access_token_with_optional_arguments(self): - auth = OAuth1Session('foo', verifier='bar') - auth.send = mock_text_response('oauth_token=foo') - resp = auth.fetch_access_token('https://example.com/token', - verify=False, stream=True) - self.assertEqual(resp['oauth_token'], 'foo') - for k, v in resp.items(): - self.assertTrue(isinstance(k, unicode_type)) - self.assertTrue(isinstance(v, unicode_type)) - - def _test_fetch_access_token_raises_error(self, session): - """Assert that an error is being raised whenever there's no verifier - passed in to the client. - """ - session.send = mock_text_response('oauth_token=foo') - - # Use a try-except block so that we can assert on the exception message - # being raised and also keep the Python2.6 compatibility where - # assertRaises is not a context manager. - try: - session.fetch_access_token('https://example.com/token') - except OAuthError as exc: - self.assertEqual(exc.error, 'missing_verifier') - - def test_fetch_token_invalid_response(self): - auth = OAuth1Session('foo') - auth.send = mock_text_response('not valid urlencoded response!') - self.assertRaises( - ValueError, auth.fetch_request_token, 'https://example.com/token') - - for code in (400, 401, 403): - auth.send = mock_text_response('valid=response', code) - # use try/catch rather than self.assertRaises, so we can - # assert on the properties of the exception - try: - auth.fetch_request_token('https://example.com/token') - except OAuthError as err: - self.assertEqual(err.error, 'fetch_token_denied') - else: # no exception raised - self.fail("ValueError not raised") - - def test_fetch_access_token_missing_verifier(self): - self._test_fetch_access_token_raises_error(OAuth1Session('foo')) - - def test_fetch_access_token_has_verifier_is_none(self): - session = OAuth1Session('foo') - session.auth.verifier = None - self._test_fetch_access_token_raises_error(session) - - def verify_signature(self, signature): - def fake_send(r, **kwargs): - auth_header = to_unicode(r.headers['Authorization']) - self.assertEqual(auth_header, signature) - resp = mock.MagicMock(spec=requests.Response) - resp.cookies = [] - return resp - return fake_send diff --git a/tests/core/test_requests_client/test_oauth2_session.py b/tests/core/test_requests_client/test_oauth2_session.py deleted file mode 100644 index 3e29629f6..000000000 --- a/tests/core/test_requests_client/test_oauth2_session.py +++ /dev/null @@ -1,509 +0,0 @@ -from __future__ import unicode_literals -import mock -import time -from copy import deepcopy -from unittest import TestCase -from authlib.common.security import generate_token -from authlib.common.urls import url_encode, add_params_to_uri -from authlib.integrations.requests_client import OAuth2Session, OAuthError -from authlib.oauth2.rfc6749 import ( - MismatchingStateException, -) -from authlib.oauth2.rfc7523 import ClientSecretJWT, PrivateKeyJWT -from tests.util import read_file_path -from tests.client_base import mock_json_response - - -class OAuth2SessionTest(TestCase): - - def setUp(self): - self.token = { - 'token_type': 'Bearer', - 'access_token': 'a', - 'refresh_token': 'b', - 'expires_in': '3600', - 'expires_at': int(time.time()) + 3600, - } - self.client_id = 'foo' - - def test_invalid_token_type(self): - token = { - 'token_type': 'invalid', - 'access_token': 'a', - 'refresh_token': 'b', - 'expires_in': '3600', - 'expires_at': int(time.time()) + 3600, - } - with OAuth2Session(self.client_id, token=token) as sess: - self.assertRaises(OAuthError, sess.get, 'https://i.b') - - def test_add_token_to_header(self): - token = 'Bearer ' + self.token['access_token'] - - def verifier(r, **kwargs): - auth_header = r.headers.get(str('Authorization'), None) - self.assertEqual(auth_header, token) - resp = mock.MagicMock() - return resp - - sess = OAuth2Session(client_id=self.client_id, token=self.token) - sess.send = verifier - sess.get('https://i.b') - - def test_add_token_to_body(self): - def verifier(r, **kwargs): - self.assertIn(self.token['access_token'], r.body) - resp = mock.MagicMock() - return resp - - sess = OAuth2Session( - client_id=self.client_id, - token=self.token, - token_placement='body' - ) - sess.send = verifier - sess.post('https://i.b') - - def test_add_token_to_uri(self): - def verifier(r, **kwargs): - self.assertIn(self.token['access_token'], r.url) - resp = mock.MagicMock() - return resp - - sess = OAuth2Session( - client_id=self.client_id, - token=self.token, - token_placement='uri' - ) - sess.send = verifier - sess.get('https://i.b') - - def test_create_authorization_url(self): - url = 'https://example.com/authorize?foo=bar' - - sess = OAuth2Session(client_id=self.client_id) - auth_url, state = sess.create_authorization_url(url) - self.assertIn(state, auth_url) - self.assertIn(self.client_id, auth_url) - self.assertIn('response_type=code', auth_url) - - sess = OAuth2Session(client_id=self.client_id, prompt='none') - auth_url, state = sess.create_authorization_url( - url, state='foo', redirect_uri='https://i.b', scope='profile') - self.assertEqual(state, 'foo') - self.assertIn('i.b', auth_url) - self.assertIn('profile', auth_url) - self.assertIn('prompt=none', auth_url) - - def test_code_challenge(self): - sess = OAuth2Session(client_id=self.client_id, code_challenge_method='S256') - - url = 'https://example.com/authorize' - auth_url, _ = sess.create_authorization_url( - url, code_verifier=generate_token(48)) - self.assertIn('code_challenge', auth_url) - self.assertIn('code_challenge_method=S256', auth_url) - - def test_token_from_fragment(self): - sess = OAuth2Session(self.client_id) - response_url = 'https://i.b/callback#' + url_encode(self.token.items()) - self.assertEqual(sess.token_from_fragment(response_url), self.token) - token = sess.fetch_token(authorization_response=response_url) - self.assertEqual(token, self.token) - - def test_fetch_token_post(self): - url = 'https://example.com/token' - - def fake_send(r, **kwargs): - self.assertIn('code=v', r.body) - self.assertIn('client_id=', r.body) - self.assertIn('grant_type=authorization_code', r.body) - resp = mock.MagicMock() - resp.json = lambda: self.token - return resp - - sess = OAuth2Session(client_id=self.client_id) - sess.send = fake_send - self.assertEqual( - sess.fetch_token( - url, authorization_response='https://i.b/?code=v'), - self.token) - - sess = OAuth2Session( - client_id=self.client_id, - token_endpoint_auth_method='none', - ) - sess.send = fake_send - token = sess.fetch_token(url, code='v') - self.assertEqual(token, self.token) - - error = {'error': 'invalid_request'} - sess = OAuth2Session(client_id=self.client_id, token=self.token) - sess.send = mock_json_response(error) - self.assertRaises(OAuthError, sess.fetch_access_token, url) - - def test_fetch_token_get(self): - url = 'https://example.com/token' - - def fake_send(r, **kwargs): - self.assertIn('code=v', r.url) - self.assertIn('grant_type=authorization_code', r.url) - resp = mock.MagicMock() - resp.json = lambda: self.token - return resp - - sess = OAuth2Session(client_id=self.client_id) - sess.send = fake_send - token = sess.fetch_token( - url, authorization_response='https://i.b/?code=v', method='GET') - self.assertEqual(token, self.token) - - sess = OAuth2Session( - client_id=self.client_id, - token_endpoint_auth_method='none', - ) - sess.send = fake_send - token = sess.fetch_token(url, code='v', method='GET') - self.assertEqual(token, self.token) - - token = sess.fetch_token(url + '?q=a', code='v', method='GET') - self.assertEqual(token, self.token) - - def test_token_auth_method_client_secret_post(self): - url = 'https://example.com/token' - - def fake_send(r, **kwargs): - self.assertIn('code=v', r.body) - self.assertIn('client_id=', r.body) - self.assertIn('client_secret=bar', r.body) - self.assertIn('grant_type=authorization_code', r.body) - resp = mock.MagicMock() - resp.json = lambda: self.token - return resp - - sess = OAuth2Session( - client_id=self.client_id, - client_secret='bar', - token_endpoint_auth_method='client_secret_post', - ) - sess.send = fake_send - token = sess.fetch_token(url, code='v') - self.assertEqual(token, self.token) - - def test_access_token_response_hook(self): - url = 'https://example.com/token' - - def access_token_response_hook(resp): - self.assertEqual(resp.json(), self.token) - return resp - - sess = OAuth2Session(client_id=self.client_id, token=self.token) - sess.register_compliance_hook( - 'access_token_response', - access_token_response_hook - ) - sess.send = mock_json_response(self.token) - self.assertEqual(sess.fetch_token(url), self.token) - - def test_password_grant_type(self): - url = 'https://example.com/token' - - def fake_send(r, **kwargs): - self.assertIn('username=v', r.body) - self.assertIn('grant_type=password', r.body) - self.assertIn('scope=profile', r.body) - resp = mock.MagicMock() - resp.json = lambda: self.token - return resp - - sess = OAuth2Session(client_id=self.client_id, scope='profile') - sess.send = fake_send - token = sess.fetch_token(url, username='v', password='v') - self.assertEqual(token, self.token) - - def test_client_credentials_type(self): - url = 'https://example.com/token' - - def fake_send(r, **kwargs): - self.assertIn('grant_type=client_credentials', r.body) - self.assertIn('scope=profile', r.body) - resp = mock.MagicMock() - resp.json = lambda: self.token - return resp - - sess = OAuth2Session( - client_id=self.client_id, - client_secret='v', - scope='profile', - ) - sess.send = fake_send - token = sess.fetch_token(url) - self.assertEqual(token, self.token) - - def test_cleans_previous_token_before_fetching_new_one(self): - """Makes sure the previous token is cleaned before fetching a new one. - The reason behind it is that, if the previous token is expired, this - method shouldn't fail with a TokenExpiredError, since it's attempting - to get a new one (which shouldn't be expired). - """ - now = int(time.time()) - new_token = deepcopy(self.token) - past = now - 7200 - self.token['expires_at'] = past - new_token['expires_at'] = now + 3600 - url = 'https://example.com/token' - - with mock.patch('time.time', lambda: now): - sess = OAuth2Session(client_id=self.client_id, token=self.token) - sess.send = mock_json_response(new_token) - self.assertEqual(sess.fetch_token(url), new_token) - - def test_mis_match_state(self): - sess = OAuth2Session('foo') - self.assertRaises( - MismatchingStateException, - sess.fetch_token, - 'https://i.b/token', - authorization_response='https://i.b/no-state?code=abc', - state='somestate', - ) - - def test_token_status(self): - token = dict(access_token='a', token_type='bearer', expires_at=100) - sess = OAuth2Session('foo', token=token) - - self.assertTrue(sess.token.is_expired) - - def test_token_expired(self): - token = dict(access_token='a', token_type='bearer', expires_at=100) - sess = OAuth2Session('foo', token=token) - self.assertRaises( - OAuthError, - sess.get, - 'https://i.b/token', - ) - - def test_missing_token(self): - sess = OAuth2Session('foo') - self.assertRaises( - OAuthError, - sess.get, - 'https://i.b/token', - ) - - def test_register_compliance_hook(self): - sess = OAuth2Session('foo') - self.assertRaises( - ValueError, - sess.register_compliance_hook, - 'invalid_hook', - lambda o: o, - ) - - def protected_request(url, headers, data): - self.assertIn('Authorization', headers) - return url, headers, data - - sess = OAuth2Session('foo', token=self.token) - sess.register_compliance_hook( - 'protected_request', - protected_request, - ) - sess.send = mock_json_response({'name': 'a'}) - sess.get('https://i.b/user') - - def test_auto_refresh_token(self): - - def _update_token(token, refresh_token=None, access_token=None): - self.assertEqual(refresh_token, 'b') - self.assertEqual(token, self.token) - - update_token = mock.Mock(side_effect=_update_token) - old_token = dict( - access_token='a', refresh_token='b', - token_type='bearer', expires_at=100 - ) - sess = OAuth2Session( - 'foo', token=old_token, - token_endpoint='https://i.b/token', - update_token=update_token, - ) - sess.send = mock_json_response(self.token) - sess.get('https://i.b/user') - self.assertTrue(update_token.called) - - def test_auto_refresh_token2(self): - - def _update_token(token, refresh_token=None, access_token=None): - self.assertEqual(access_token, 'a') - self.assertEqual(token, self.token) - - update_token = mock.Mock(side_effect=_update_token) - old_token = dict( - access_token='a', - token_type='bearer', - expires_at=100 - ) - - sess = OAuth2Session( - 'foo', token=old_token, - token_endpoint='https://i.b/token', - grant_type='client_credentials', - ) - sess.send = mock_json_response(self.token) - sess.get('https://i.b/user') - self.assertFalse(update_token.called) - - sess = OAuth2Session( - 'foo', token=old_token, - token_endpoint='https://i.b/token', - grant_type='client_credentials', - update_token=update_token, - ) - sess.send = mock_json_response(self.token) - sess.get('https://i.b/user') - self.assertTrue(update_token.called) - - def test_revoke_token(self): - sess = OAuth2Session('a') - answer = {'status': 'ok'} - sess.send = mock_json_response(answer) - resp = sess.revoke_token('https://i.b/token', 'hi') - self.assertEqual(resp.json(), answer) - resp = sess.revoke_token( - 'https://i.b/token', 'hi', - token_type_hint='access_token' - ) - self.assertEqual(resp.json(), answer) - - def revoke_token_request(url, headers, data): - self.assertEqual(url, 'https://i.b/token') - return url, headers, data - - sess.register_compliance_hook( - 'revoke_token_request', - revoke_token_request, - ) - sess.revoke_token( - 'https://i.b/token', 'hi', - body='', - token_type_hint='access_token' - ) - - def test_introspect_token(self): - sess = OAuth2Session('a') - answer = { - "active": True, - "client_id": "l238j323ds-23ij4", - "username": "jdoe", - "scope": "read write dolphin", - "sub": "Z5O3upPC88QrAjx00dis", - "aud": "https://protected.example.net/resource", - "iss": "https://server.example.com/", - "exp": 1419356238, - "iat": 1419350238 - } - sess.send = mock_json_response(answer) - resp = sess.introspect_token('https://i.b/token', 'hi') - self.assertEqual(resp.json(), answer) - - def test_client_secret_jwt(self): - sess = OAuth2Session( - 'id', 'secret', - token_endpoint_auth_method='client_secret_jwt' - ) - sess.register_client_auth_method(ClientSecretJWT()) - - def fake_send(r, **kwargs): - self.assertIn('client_assertion=', r.body) - self.assertIn('client_assertion_type=', r.body) - resp = mock.MagicMock() - resp.json = lambda: self.token - return resp - - sess.send = fake_send - token = sess.fetch_token('https://i.b/token') - self.assertEqual(token, self.token) - - def test_client_secret_jwt2(self): - sess = OAuth2Session( - 'id', 'secret', - token_endpoint_auth_method=ClientSecretJWT(), - ) - - def fake_send(r, **kwargs): - self.assertIn('client_assertion=', r.body) - self.assertIn('client_assertion_type=', r.body) - resp = mock.MagicMock() - resp.json = lambda: self.token - return resp - - sess.send = fake_send - token = sess.fetch_token('https://i.b/token') - self.assertEqual(token, self.token) - - def test_private_key_jwt(self): - client_secret = read_file_path('rsa_private.pem') - sess = OAuth2Session( - 'id', client_secret, - token_endpoint_auth_method='private_key_jwt' - ) - sess.register_client_auth_method(PrivateKeyJWT()) - - def fake_send(r, **kwargs): - self.assertIn('client_assertion=', r.body) - self.assertIn('client_assertion_type=', r.body) - resp = mock.MagicMock() - resp.json = lambda: self.token - return resp - - sess.send = fake_send - token = sess.fetch_token('https://i.b/token') - self.assertEqual(token, self.token) - - def test_custom_client_auth_method(self): - def auth_client(client, method, uri, headers, body): - uri = add_params_to_uri(uri, [ - ('client_id', client.client_id), - ('client_secret', client.client_secret), - ]) - uri = uri + '&' + body - body = '' - return uri, headers, body - - sess = OAuth2Session( - 'id', 'secret', - token_endpoint_auth_method='client_secret_uri' - ) - sess.register_client_auth_method(('client_secret_uri', auth_client)) - - def fake_send(r, **kwargs): - self.assertIn('client_id=', r.url) - self.assertIn('client_secret=', r.url) - resp = mock.MagicMock() - resp.json = lambda: self.token - return resp - - sess.send = fake_send - token = sess.fetch_token('https://i.b/token') - self.assertEqual(token, self.token) - - def test_use_client_token_auth(self): - import requests - - token = 'Bearer ' + self.token['access_token'] - - def verifier(r, **kwargs): - auth_header = r.headers.get(str('Authorization'), None) - self.assertEqual(auth_header, token) - resp = mock.MagicMock() - return resp - - client = OAuth2Session( - client_id=self.client_id, - token=self.token - ) - - sess = requests.Session() - sess.send = verifier - sess.get('https://i.b', auth=client.token_auth) diff --git a/tests/django/conftest.py b/tests/django/conftest.py new file mode 100644 index 000000000..2fbab8774 --- /dev/null +++ b/tests/django/conftest.py @@ -0,0 +1,8 @@ +import pytest + +from tests.django_helper import RequestClient + + +@pytest.fixture +def factory(): + return RequestClient() diff --git a/tests/django/settings.py b/tests/django/settings.py deleted file mode 100644 index 92136d048..000000000 --- a/tests/django/settings.py +++ /dev/null @@ -1,39 +0,0 @@ -SECRET_KEY = 'django-secret' - -DATABASES = { - "default": { - "ENGINE": "django.db.backends.sqlite3", - "NAME": "example.sqlite", - } -} - -MIDDLEWARE = [ - 'django.contrib.sessions.middleware.SessionMiddleware' -] - -SESSION_ENGINE = 'django.contrib.sessions.backends.cache' - -CACHES = { - 'default': { - 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', - 'LOCATION': 'unique-snowflake', - } -} - -INSTALLED_APPS=[ - 'django.contrib.contenttypes', - 'django.contrib.auth', - 'tests.django.test_oauth1', - 'tests.django.test_oauth2', -] - -AUTHLIB_OAUTH_CLIENTS = { - 'dev_overwrite': { - 'client_id': 'dev-client-id', - 'client_secret': 'dev-client-secret', - 'access_token_params': { - 'foo': 'foo-1', - 'bar': 'bar-2' - } - } -} diff --git a/tests/django/test_client/test_oauth_client.py b/tests/django/test_client/test_oauth_client.py deleted file mode 100644 index f05809031..000000000 --- a/tests/django/test_client/test_oauth_client.py +++ /dev/null @@ -1,294 +0,0 @@ -from __future__ import unicode_literals, print_function - -import mock -from django.test import override_settings -from authlib.integrations.django_client import OAuth, OAuthError -from tests.django.base import TestCase -from tests.client_base import ( - mock_send_value, - get_bearer_token -) - -dev_client = { - 'client_id': 'dev-key', - 'client_secret': 'dev-secret' -} - - -class DjangoOAuthTest(TestCase): - def test_register_remote_app(self): - oauth = OAuth() - self.assertRaises(AttributeError, lambda: oauth.dev) - - oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - request_token_url='https://i.b/reqeust-token', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize' - ) - self.assertEqual(oauth.dev.name, 'dev') - self.assertEqual(oauth.dev.client_id, 'dev') - - def test_register_with_overwrite(self): - oauth = OAuth() - oauth.register( - 'dev_overwrite', - overwrite=True, - client_id='dev', - client_secret='dev', - request_token_url='https://i.b/reqeust-token', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - access_token_params={ - 'foo': 'foo' - }, - authorize_url='https://i.b/authorize' - ) - self.assertEqual(oauth.dev_overwrite.client_id, 'dev-client-id') - self.assertEqual( - oauth.dev_overwrite.access_token_params['foo'], 'foo-1') - - @override_settings(AUTHLIB_OAUTH_CLIENTS={'dev': dev_client}) - def test_register_from_settings(self): - oauth = OAuth() - oauth.register('dev') - self.assertEqual(oauth.dev.client_id, 'dev-key') - self.assertEqual(oauth.dev.client_secret, 'dev-secret') - - def test_oauth1_authorize(self): - request = self.factory.get('/login') - request.session = self.factory.session - - oauth = OAuth() - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - request_token_url='https://i.b/reqeust-token', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize', - ) - - with mock.patch('requests.sessions.Session.send') as send: - send.return_value = mock_send_value('oauth_token=foo&oauth_verifier=baz') - - resp = client.authorize_redirect(request) - self.assertEqual(resp.status_code, 302) - url = resp.get('Location') - self.assertIn('oauth_token=foo', url) - - with mock.patch('requests.sessions.Session.send') as send: - send.return_value = mock_send_value('oauth_token=a&oauth_token_secret=b') - token = client.authorize_access_token(request) - self.assertEqual(token['oauth_token'], 'a') - - def test_oauth2_authorize(self): - request = self.factory.get('/login') - request.session = self.factory.session - - oauth = OAuth() - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize', - ) - rv = client.authorize_redirect(request, 'https://a.b/c') - self.assertEqual(rv.status_code, 302) - url = rv.get('Location') - self.assertIn('state=', url) - state = request.session['_dev_authlib_state_'] - - with mock.patch('requests.sessions.Session.send') as send: - send.return_value = mock_send_value(get_bearer_token()) - request = self.factory.get('/authorize?state={}'.format(state)) - request.session = self.factory.session - request.session['_dev_authlib_state_'] = state - - token = client.authorize_access_token(request) - self.assertEqual(token['access_token'], 'a') - - def test_oauth2_authorize_code_challenge(self): - request = self.factory.get('/login') - request.session = self.factory.session - - oauth = OAuth() - client = oauth.register( - 'dev', - client_id='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize', - client_kwargs={'code_challenge_method': 'S256'}, - ) - rv = client.authorize_redirect(request, 'https://a.b/c') - self.assertEqual(rv.status_code, 302) - url = rv.get('Location') - self.assertIn('state=', url) - self.assertIn('code_challenge=', url) - state = request.session['_dev_authlib_state_'] - verifier = request.session['_dev_authlib_code_verifier_'] - - def fake_send(sess, req, **kwargs): - self.assertIn('code_verifier={}'.format(verifier), req.body) - return mock_send_value(get_bearer_token()) - - with mock.patch('requests.sessions.Session.send', fake_send): - request = self.factory.get('/authorize?state={}'.format(state)) - request.session = self.factory.session - request.session['_dev_authlib_state_'] = state - request.session['_dev_authlib_code_verifier_'] = verifier - - token = client.authorize_access_token(request) - self.assertEqual(token['access_token'], 'a') - - def test_oauth2_authorize_code_verifier(self): - request = self.factory.get('/login') - request.session = self.factory.session - - oauth = OAuth() - client = oauth.register( - 'dev', - client_id='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize', - client_kwargs={'code_challenge_method': 'S256'}, - ) - state = 'foo' - code_verifier = 'bar' - rv = client.authorize_redirect( - request, 'https://a.b/c', - state=state, code_verifier=code_verifier - ) - self.assertEqual(rv.status_code, 302) - url = rv.get('Location') - self.assertIn('state=', url) - self.assertIn('code_challenge=', url) - - with mock.patch('requests.sessions.Session.send') as send: - send.return_value = mock_send_value(get_bearer_token()) - - request = self.factory.get('/authorize?state={}'.format(state)) - request.session = self.factory.session - request.session['_dev_authlib_state_'] = state - request.session['_dev_authlib_code_verifier_'] = code_verifier - - token = client.authorize_access_token(request) - self.assertEqual(token['access_token'], 'a') - - def test_openid_authorize(self): - request = self.factory.get('/login') - request.session = self.factory.session - - oauth = OAuth() - client = oauth.register( - 'dev', - client_id='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize', - client_kwargs={'scope': 'openid profile'}, - ) - - resp = client.authorize_redirect(request, 'https://b.com/bar') - self.assertEqual(resp.status_code, 302) - nonce = request.session['_dev_authlib_nonce_'] - self.assertIsNotNone(nonce) - url = resp.get('Location') - self.assertIn('nonce={}'.format(nonce), url) - - def test_oauth2_access_token_with_post(self): - oauth = OAuth() - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize', - ) - payload = {'code': 'a', 'state': 'b'} - - with mock.patch('requests.sessions.Session.send') as send: - send.return_value = mock_send_value(get_bearer_token()) - request = self.factory.post('/token', data=payload) - request.session = self.factory.session - request.session['_dev_authlib_state_'] = 'b' - token = client.authorize_access_token(request) - self.assertEqual(token['access_token'], 'a') - - def test_with_fetch_token_in_oauth(self): - def fetch_token(name, request): - return {'access_token': name, 'token_type': 'bearer'} - - oauth = OAuth(fetch_token) - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize' - ) - - def fake_send(sess, req, **kwargs): - self.assertEqual(sess.token['access_token'], 'dev') - return mock_send_value(get_bearer_token()) - - with mock.patch('requests.sessions.Session.send', fake_send): - request = self.factory.get('/login') - client.get('/user', request=request) - - def test_with_fetch_token_in_register(self): - def fetch_token(request): - return {'access_token': 'dev', 'token_type': 'bearer'} - - oauth = OAuth() - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize', - fetch_token=fetch_token, - ) - - def fake_send(sess, req, **kwargs): - self.assertEqual(sess.token['access_token'], 'dev') - return mock_send_value(get_bearer_token()) - - with mock.patch('requests.sessions.Session.send', fake_send): - request = self.factory.get('/login') - client.get('/user', request=request) - - def test_request_without_token(self): - oauth = OAuth() - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize' - ) - - def fake_send(sess, req, **kwargs): - auth = req.headers.get('Authorization') - self.assertIsNone(auth) - resp = mock.MagicMock() - resp.text = 'hi' - resp.status_code = 200 - return resp - - with mock.patch('requests.sessions.Session.send', fake_send): - resp = client.get('/api/user', withhold_token=True) - self.assertEqual(resp.text, 'hi') - self.assertRaises(OAuthError, client.get, 'https://i.b/api/user') diff --git a/tests/django/test_oauth1/conftest.py b/tests/django/test_oauth1/conftest.py new file mode 100644 index 000000000..fb526e246 --- /dev/null +++ b/tests/django/test_oauth1/conftest.py @@ -0,0 +1,59 @@ +import os + +import pytest + +from authlib.integrations.django_oauth1 import CacheAuthorizationServer + +from .models import Client +from .models import TokenCredential +from .models import User + +pytestmark = pytest.mark.django_db + + +@pytest.fixture(autouse=True) +def env(): + os.environ["AUTHLIB_INSECURE_TRANSPORT"] = "true" + yield + del os.environ["AUTHLIB_INSECURE_TRANSPORT"] + + +@pytest.fixture +def server(settings): + """Create server that respects current settings.""" + return CacheAuthorizationServer(Client, TokenCredential) + + +@pytest.fixture +def plaintext_server(settings): + """Server configured with PLAINTEXT signature method.""" + settings.AUTHLIB_OAUTH1_PROVIDER = {"signature_methods": ["PLAINTEXT"]} + return CacheAuthorizationServer(Client, TokenCredential) + + +@pytest.fixture +def rsa_server(settings): + """Server configured with RSA-SHA1 signature method.""" + settings.AUTHLIB_OAUTH1_PROVIDER = {"signature_methods": ["RSA-SHA1"]} + return CacheAuthorizationServer(Client, TokenCredential) + + +@pytest.fixture(autouse=True) +def user(db): + user = User(username="foo") + user.save() + yield user + user.delete() + + +@pytest.fixture(autouse=True) +def client(user, db): + client = Client( + user_id=user.pk, + client_id="client", + client_secret="secret", + default_redirect_uri="https://client.test", + ) + client.save() + yield client + client.delete() diff --git a/tests/django/test_oauth1/models.py b/tests/django/test_oauth1/models.py index c5ccd0e99..f90aa7484 100644 --- a/tests/django/test_oauth1/models.py +++ b/tests/django/test_oauth1/models.py @@ -1,6 +1,10 @@ -from django.db.models import Model, CharField, TextField -from django.db.models import ForeignKey, CASCADE from django.contrib.auth.models import User +from django.db.models import CASCADE +from django.db.models import CharField +from django.db.models import ForeignKey +from django.db.models import Model +from django.db.models import TextField + from tests.util import read_file_path @@ -8,7 +12,7 @@ class Client(Model): user = ForeignKey(User, on_delete=CASCADE) client_id = CharField(max_length=48, unique=True, db_index=True) client_secret = CharField(max_length=48, blank=True) - default_redirect_uri = TextField(blank=False, default='') + default_redirect_uri = TextField(blank=False, default="") def get_default_redirect_uri(self): return self.default_redirect_uri @@ -17,7 +21,7 @@ def get_client_secret(self): return self.client_secret def get_rsa_public_key(self): - return read_file_path('rsa_public.pem') + return read_file_path("rsa_public.pem") class TokenCredential(Model): diff --git a/tests/django/test_oauth1/oauth1_server.py b/tests/django/test_oauth1/oauth1_server.py deleted file mode 100644 index 6e161239b..000000000 --- a/tests/django/test_oauth1/oauth1_server.py +++ /dev/null @@ -1,13 +0,0 @@ -import os -from authlib.integrations.django_oauth1 import ( - CacheAuthorizationServer, -) -from .models import Client, TokenCredential -from ..base import TestCase as _TestCase - -os.environ['AUTHLIB_INSECURE_TRANSPORT'] = 'true' - - -class TestCase(_TestCase): - def create_server(self): - return CacheAuthorizationServer(Client, TokenCredential) diff --git a/tests/django/test_oauth1/test_authorize.py b/tests/django/test_oauth1/test_authorize.py index a88134655..265e43956 100644 --- a/tests/django/test_oauth1/test_authorize.py +++ b/tests/django/test_oauth1/test_authorize.py @@ -1,142 +1,128 @@ +import pytest + from authlib.oauth1.rfc5849 import errors -from django.test import override_settings from tests.util import decode_response -from .models import User, Client -from .oauth1_server import TestCase - - -class AuthorizationTest(TestCase): - def prepare_data(self): - user = User(username='foo') - user.save() - client = Client( - user_id=user.pk, - client_id='client', - client_secret='secret', - default_redirect_uri='https://a.b', - ) - client.save() - - def test_invalid_authorization(self): - server = self.create_server() - url = '/oauth/authorize' - request = self.factory.post(url) - self.assertRaises( - errors.MissingRequiredParameterError, - server.check_authorization_request, - request - ) - - request = self.factory.post(url, data={'oauth_token': 'a'}) - self.assertRaises( - errors.InvalidTokenError, - server.check_authorization_request, - request - ) - - def test_invalid_initiate(self): - server = self.create_server() - url = '/oauth/initiate' - request = self.factory.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'oob', - 'oauth_signature_method': 'PLAINTEXT', - 'oauth_signature': 'secret&' - }) - resp = server.create_temporary_credentials_response(request) - data = decode_response(resp.content) - self.assertEqual(data['error'], 'invalid_client') - - @override_settings( - AUTHLIB_OAUTH1_PROVIDER={'signature_methods': ['PLAINTEXT']}) - def test_authorize_denied(self): - self.prepare_data() - server = self.create_server() - initiate_url = '/oauth/initiate' - authorize_url = '/oauth/authorize' - - # case 1 - request = self.factory.post(initiate_url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'oob', - 'oauth_signature_method': 'PLAINTEXT', - 'oauth_signature': 'secret&' - }) - resp = server.create_temporary_credentials_response(request) - data = decode_response(resp.content) - self.assertIn('oauth_token', data) - - request = self.factory.post(authorize_url, data={ - 'oauth_token': data['oauth_token'] - }) - resp = server.create_authorization_response(request) - self.assertEqual(resp.status_code, 302) - self.assertIn('access_denied', resp['Location']) - self.assertIn('https://a.b', resp['Location']) - - # case 2 - request = self.factory.post(initiate_url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'https://i.test', - 'oauth_signature_method': 'PLAINTEXT', - 'oauth_signature': 'secret&' - }) - resp = server.create_temporary_credentials_response(request) - data = decode_response(resp.content) - self.assertIn('oauth_token', data) - request = self.factory.post(authorize_url, data={ - 'oauth_token': data['oauth_token'] - }) - resp = server.create_authorization_response(request) - self.assertEqual(resp.status_code, 302) - self.assertIn('access_denied', resp['Location']) - self.assertIn('https://i.test', resp['Location']) - - @override_settings( - AUTHLIB_OAUTH1_PROVIDER={'signature_methods': ['PLAINTEXT']}) - def test_authorize_granted(self): - self.prepare_data() - server = self.create_server() - user = User.objects.get(username='foo') - initiate_url = '/oauth/initiate' - authorize_url = '/oauth/authorize' - - # case 1 - request = self.factory.post(initiate_url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'oob', - 'oauth_signature_method': 'PLAINTEXT', - 'oauth_signature': 'secret&' - }) - resp = server.create_temporary_credentials_response(request) - data = decode_response(resp.content) - self.assertIn('oauth_token', data) - - request = self.factory.post(authorize_url, data={ - 'oauth_token': data['oauth_token'] - }) - resp = server.create_authorization_response(request, user) - self.assertEqual(resp.status_code, 302) - - self.assertIn('oauth_verifier', resp['Location']) - self.assertIn('https://a.b', resp['Location']) - - # case 2 - request = self.factory.post(initiate_url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'https://i.test', - 'oauth_signature_method': 'PLAINTEXT', - 'oauth_signature': 'secret&' - }) - resp = server.create_temporary_credentials_response(request) - data = decode_response(resp.content) - self.assertIn('oauth_token', data) - - request = self.factory.post(authorize_url, data={ - 'oauth_token': data['oauth_token'] - }) - resp = server.create_authorization_response(request, user) - - self.assertEqual(resp.status_code, 302) - self.assertIn('oauth_verifier', resp['Location']) - self.assertIn('https://i.test', resp['Location']) + +from .models import User + + +def test_invalid_authorization(factory, server): + url = "/oauth/authorize" + request = factory.post(url) + with pytest.raises(errors.MissingRequiredParameterError): + server.check_authorization_request(request) + + request = factory.post(url, data={"oauth_token": "a"}) + with pytest.raises(errors.InvalidTokenError): + server.check_authorization_request(request) + + +def test_invalid_initiate(factory, server): + url = "/oauth/initiate" + # Test with non-existent client + request = factory.post( + url, + data={ + "oauth_consumer_key": "nonexistent", # Client doesn't exist + "oauth_callback": "oob", + "oauth_signature_method": "PLAINTEXT", + "oauth_signature": "secret&", + }, + ) + resp = server.create_temporary_credentials_response(request) + data = decode_response(resp.content) + assert data["error"] == "invalid_client" + + +def test_authorize_denied(factory, plaintext_server): + server = plaintext_server + initiate_url = "/oauth/initiate" + authorize_url = "/oauth/authorize" + + # case 1 + request = factory.post( + initiate_url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "oob", + "oauth_signature_method": "PLAINTEXT", + "oauth_signature": "secret&", + }, + ) + resp = server.create_temporary_credentials_response(request) + data = decode_response(resp.content) + assert "oauth_token" in data + + request = factory.post(authorize_url, data={"oauth_token": data["oauth_token"]}) + resp = server.create_authorization_response(request) + assert resp.status_code == 302 + assert "access_denied" in resp["Location"] + assert "https://client.test" in resp["Location"] + + # case 2 + request = factory.post( + initiate_url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "https://i.test", + "oauth_signature_method": "PLAINTEXT", + "oauth_signature": "secret&", + }, + ) + resp = server.create_temporary_credentials_response(request) + data = decode_response(resp.content) + assert "oauth_token" in data + request = factory.post(authorize_url, data={"oauth_token": data["oauth_token"]}) + resp = server.create_authorization_response(request) + assert resp.status_code == 302 + assert "access_denied" in resp["Location"] + assert "https://i.test" in resp["Location"] + + +def test_authorize_granted(factory, plaintext_server): + server = plaintext_server + user = User.objects.get(username="foo") + initiate_url = "/oauth/initiate" + authorize_url = "/oauth/authorize" + + # case 1 + request = factory.post( + initiate_url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "oob", + "oauth_signature_method": "PLAINTEXT", + "oauth_signature": "secret&", + }, + ) + resp = server.create_temporary_credentials_response(request) + data = decode_response(resp.content) + assert "oauth_token" in data + + request = factory.post(authorize_url, data={"oauth_token": data["oauth_token"]}) + resp = server.create_authorization_response(request, user) + assert resp.status_code == 302 + + assert "oauth_verifier" in resp["Location"] + assert "https://client.test" in resp["Location"] + + # case 2 + request = factory.post( + initiate_url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "https://i.test", + "oauth_signature_method": "PLAINTEXT", + "oauth_signature": "secret&", + }, + ) + resp = server.create_temporary_credentials_response(request) + data = decode_response(resp.content) + assert "oauth_token" in data + + request = factory.post(authorize_url, data={"oauth_token": data["oauth_token"]}) + resp = server.create_authorization_response(request, user) + + assert resp.status_code == 302 + assert "oauth_verifier" in resp["Location"] + assert "https://i.test" in resp["Location"] diff --git a/tests/django/test_oauth1/test_resource_protector.py b/tests/django/test_oauth1/test_resource_protector.py index 3466b04bf..1a0a01f3f 100644 --- a/tests/django/test_oauth1/test_resource_protector.py +++ b/tests/django/test_oauth1/test_resource_protector.py @@ -1,189 +1,196 @@ import json import time + +import pytest +from django.http import JsonResponse +from django.test import override_settings + from authlib.common.encoding import to_unicode -from authlib.oauth1.rfc5849 import signature from authlib.common.urls import add_params_to_uri from authlib.integrations.django_oauth1 import ResourceProtector -from django.http import JsonResponse -from django.test import override_settings +from authlib.oauth1.rfc5849 import signature from tests.util import read_file_path -from .models import User, Client, TokenCredential -from .oauth1_server import TestCase - - -class ResourceTest(TestCase): - def create_route(self): - require_oauth = ResourceProtector(Client, TokenCredential) - @require_oauth() - def handle(request): - user = request.oauth1_credential.user - return JsonResponse(dict(username=user.username)) - return handle - - def prepare_data(self): - user = User(username='foo') - user.save() - - client = Client( - user_id=user.pk, - client_id='client', - client_secret='secret', - default_redirect_uri='https://a.b', - ) - client.save() - - tok = TokenCredential( - user_id=user.pk, - client_id=client.client_id, - oauth_token='valid-token', - oauth_token_secret='valid-token-secret' - ) - tok.save() - - def test_invalid_request_parameters(self): - self.prepare_data() - handle = self.create_route() - url = '/user' - - # case 1 - request = self.factory.get(url) - resp = handle(request) - data = json.loads(to_unicode(resp.content)) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_consumer_key', data['error_description']) - - # case 2 - request = self.factory.get( - add_params_to_uri(url, {'oauth_consumer_key': 'a'})) - resp = handle(request) - data = json.loads(to_unicode(resp.content)) - self.assertEqual(data['error'], 'invalid_client') - - # case 3 - request = self.factory.get( - add_params_to_uri(url, {'oauth_consumer_key': 'client'})) - resp = handle(request) - data = json.loads(to_unicode(resp.content)) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_token', data['error_description']) - - # case 4 - request = self.factory.get( - add_params_to_uri(url, { - 'oauth_consumer_key': 'client', - 'oauth_token': 'a' - }) +from .models import Client +from .models import TokenCredential + + +def create_route(): + require_oauth = ResourceProtector(Client, TokenCredential) + + @require_oauth() + def handle(request): + user = request.oauth1_credential.user + return JsonResponse(dict(username=user.username)) + + return handle + + +@pytest.fixture(autouse=True) +def token(user, client, db): + token = TokenCredential( + user_id=user.pk, + client_id=client.client_id, + oauth_token="valid-token", + oauth_token_secret="valid-token-secret", + ) + token.save() + yield token + token.delete() + + +def test_invalid_request_parameters(factory): + handle = create_route() + url = "/user" + + # case 1 + request = factory.get(url) + resp = handle(request) + data = json.loads(to_unicode(resp.content)) + assert data["error"] == "missing_required_parameter" + assert "oauth_consumer_key" in data["error_description"] + + # case 2 + request = factory.get(add_params_to_uri(url, {"oauth_consumer_key": "a"})) + resp = handle(request) + data = json.loads(to_unicode(resp.content)) + assert data["error"] == "invalid_client" + + # case 3 + request = factory.get(add_params_to_uri(url, {"oauth_consumer_key": "client"})) + resp = handle(request) + data = json.loads(to_unicode(resp.content)) + assert data["error"] == "missing_required_parameter" + assert "oauth_token" in data["error_description"] + + # case 4 + request = factory.get( + add_params_to_uri(url, {"oauth_consumer_key": "client", "oauth_token": "a"}) + ) + resp = handle(request) + data = json.loads(to_unicode(resp.content)) + assert data["error"] == "invalid_token" + + # case 5 + request = factory.get( + add_params_to_uri( + url, {"oauth_consumer_key": "client", "oauth_token": "valid-token"} ) - resp = handle(request) - data = json.loads(to_unicode(resp.content)) - self.assertEqual(data['error'], 'invalid_token') - - # case 5 - request = self.factory.get( - add_params_to_uri(url, { - 'oauth_consumer_key': 'client', - 'oauth_token': 'valid-token' - }) - ) - resp = handle(request) - data = json.loads(to_unicode(resp.content)) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_timestamp', data['error_description']) - - @override_settings( - AUTHLIB_OAUTH1_PROVIDER={'signature_methods': ['PLAINTEXT']}) - def test_plaintext_signature(self): - self.prepare_data() - handle = self.create_route() - url = '/user' - - # case 1: success - auth_header = ( - 'OAuth oauth_consumer_key="client",' - 'oauth_signature_method="PLAINTEXT",' - 'oauth_token="valid-token",' - 'oauth_signature="secret&valid-token-secret"' - ) - request = self.factory.get(url, HTTP_AUTHORIZATION=auth_header) - resp = handle(request) - data = json.loads(to_unicode(resp.content)) - self.assertIn('username', data) - - # case 2: invalid signature - auth_header = auth_header.replace('valid-token-secret', 'invalid') - request = self.factory.get(url, HTTP_AUTHORIZATION=auth_header) - resp = handle(request) - data = json.loads(to_unicode(resp.content)) - self.assertEqual(data['error'], 'invalid_signature') - - def test_hmac_sha1_signature(self): - self.prepare_data() - handle = self.create_route() - url = '/user' - - params = [ - ('oauth_consumer_key', 'client'), - ('oauth_token', 'valid-token'), - ('oauth_signature_method', 'HMAC-SHA1'), - ('oauth_timestamp', str(int(time.time()))), - ('oauth_nonce', 'hmac-sha1-nonce'), - ] - base_string = signature.construct_base_string( - 'GET', 'http://testserver/user', params - ) - sig = signature.hmac_sha1_signature( - base_string, 'secret', 'valid-token-secret') - params.append(('oauth_signature', sig)) - auth_param = ','.join(['{}="{}"'.format(k, v) for k, v in params]) - auth_header = 'OAuth ' + auth_param - - # case 1: success - request = self.factory.get(url, HTTP_AUTHORIZATION=auth_header) - resp = handle(request) - data = json.loads(to_unicode(resp.content)) - self.assertIn('username', data) - - # case 2: exists nonce - request = self.factory.get(url, HTTP_AUTHORIZATION=auth_header) - resp = handle(request) - data = json.loads(to_unicode(resp.content)) - self.assertEqual(data['error'], 'invalid_nonce') - - @override_settings( - AUTHLIB_OAUTH1_PROVIDER={'signature_methods': ['RSA-SHA1']}) - def test_rsa_sha1_signature(self): - self.prepare_data() - handle = self.create_route() - - url = '/user' - - params = [ - ('oauth_consumer_key', 'client'), - ('oauth_token', 'valid-token'), - ('oauth_signature_method', 'RSA-SHA1'), - ('oauth_timestamp', str(int(time.time()))), - ('oauth_nonce', 'rsa-sha1-nonce'), - ] - base_string = signature.construct_base_string( - 'GET', 'http://testserver/user', params - ) - sig = signature.rsa_sha1_signature( - base_string, read_file_path('rsa_private.pem')) - params.append(('oauth_signature', sig)) - auth_param = ','.join(['{}="{}"'.format(k, v) for k, v in params]) - auth_header = 'OAuth ' + auth_param - - request = self.factory.get(url, HTTP_AUTHORIZATION=auth_header) - resp = handle(request) - data = json.loads(to_unicode(resp.content)) - self.assertIn('username', data) - - # case: invalid signature - auth_param = auth_param.replace('rsa-sha1-nonce', 'alt-sha1-nonce') - auth_header = 'OAuth ' + auth_param - request = self.factory.get(url, HTTP_AUTHORIZATION=auth_header) - resp = handle(request) - data = json.loads(to_unicode(resp.content)) - self.assertEqual(data['error'], 'invalid_signature') - + ) + resp = handle(request) + data = json.loads(to_unicode(resp.content)) + assert data["error"] == "missing_required_parameter" + assert "oauth_timestamp" in data["error_description"] + + +@override_settings(AUTHLIB_OAUTH1_PROVIDER={"signature_methods": ["PLAINTEXT"]}) +def test_plaintext_signature(factory): + handle = create_route() + url = "/user" + + # case 1: success + auth_header = ( + 'OAuth oauth_consumer_key="client",' + 'oauth_signature_method="PLAINTEXT",' + 'oauth_token="valid-token",' + 'oauth_signature="secret&valid-token-secret"' + ) + request = factory.get(url, HTTP_AUTHORIZATION=auth_header) + resp = handle(request) + data = json.loads(to_unicode(resp.content)) + assert "username" in data + + # case 2: invalid signature + auth_header = auth_header.replace("valid-token-secret", "invalid") + request = factory.get(url, HTTP_AUTHORIZATION=auth_header) + resp = handle(request) + data = json.loads(to_unicode(resp.content)) + assert data["error"] == "invalid_signature" + + +def test_hmac_sha1_signature(factory): + handle = create_route() + url = "/user" + + params = [ + ("oauth_consumer_key", "client"), + ("oauth_token", "valid-token"), + ("oauth_signature_method", "HMAC-SHA1"), + ("oauth_timestamp", str(int(time.time()))), + ("oauth_nonce", "hmac-sha1-nonce"), + ] + base_string = signature.construct_base_string( + "GET", "http://testserver/user", params + ) + sig = signature.hmac_sha1_signature(base_string, "secret", "valid-token-secret") + params.append(("oauth_signature", sig)) + auth_param = ",".join([f'{k}="{v}"' for k, v in params]) + auth_header = "OAuth " + auth_param + + # case 1: success + request = factory.get(url, HTTP_AUTHORIZATION=auth_header) + resp = handle(request) + data = json.loads(to_unicode(resp.content)) + assert "username" in data + + # case 2: exists nonce + request = factory.get(url, HTTP_AUTHORIZATION=auth_header) + resp = handle(request) + data = json.loads(to_unicode(resp.content)) + assert data["error"] == "invalid_nonce" + + +@override_settings(AUTHLIB_OAUTH1_PROVIDER={"signature_methods": ["RSA-SHA1"]}) +def test_rsa_sha1_signature(factory): + handle = create_route() + + url = "/user" + + params = [ + ("oauth_consumer_key", "client"), + ("oauth_token", "valid-token"), + ("oauth_signature_method", "RSA-SHA1"), + ("oauth_timestamp", str(int(time.time()))), + ("oauth_nonce", "rsa-sha1-nonce"), + ] + base_string = signature.construct_base_string( + "GET", "http://testserver/user", params + ) + sig = signature.rsa_sha1_signature(base_string, read_file_path("rsa_private.pem")) + params.append(("oauth_signature", sig)) + auth_param = ",".join([f'{k}="{v}"' for k, v in params]) + auth_header = "OAuth " + auth_param + + request = factory.get(url, HTTP_AUTHORIZATION=auth_header) + resp = handle(request) + data = json.loads(to_unicode(resp.content)) + assert "username" in data + + # case: invalid signature + auth_param = auth_param.replace("rsa-sha1-nonce", "alt-sha1-nonce") + auth_header = "OAuth " + auth_param + request = factory.get(url, HTTP_AUTHORIZATION=auth_header) + resp = handle(request) + data = json.loads(to_unicode(resp.content)) + assert data["error"] == "invalid_signature" + + +@override_settings(AUTHLIB_OAUTH1_PROVIDER={"signature_methods": ["PLAINTEXT"]}) +def test_decorator_without_parentheses(factory): + require_oauth = ResourceProtector(Client, TokenCredential) + + @require_oauth + def handle(request): + user = request.oauth1_credential.user + return JsonResponse(dict(username=user.username)) + + auth_header = ( + 'OAuth oauth_consumer_key="client",' + 'oauth_signature_method="PLAINTEXT",' + 'oauth_token="valid-token",' + 'oauth_signature="secret&valid-token-secret"' + ) + request = factory.get("/user", HTTP_AUTHORIZATION=auth_header) + resp = handle(request) + data = json.loads(to_unicode(resp.content)) + assert "username" in data diff --git a/tests/django/test_oauth1/test_token_credentials.py b/tests/django/test_oauth1/test_token_credentials.py index 9e0140e3f..2807b0ede 100644 --- a/tests/django/test_oauth1/test_token_credentials.py +++ b/tests/django/test_oauth1/test_token_credentials.py @@ -1,188 +1,173 @@ import time -from authlib.oauth1.rfc5849 import signature -from tests.util import read_file_path, decode_response -from django.test import override_settings + from django.core.cache import cache -from .models import User, Client -from .oauth1_server import TestCase - - -class AuthorizationTest(TestCase): - def prepare_data(self): - user = User(username='foo') - user.save() - client = Client( - user_id=user.pk, - client_id='client', - client_secret='secret', - default_redirect_uri='https://a.b', - ) - client.save() - - def prepare_temporary_credential(self, server): - token = { - 'oauth_token': 'abc', - 'oauth_token_secret': 'abc-secret', - 'oauth_verifier': 'abc-verifier', - 'client_id': 'client', - 'user_id': 1 - } - key_prefix = server._temporary_credential_key_prefix - key = key_prefix + token['oauth_token'] - cache.set(key, token, timeout=server._temporary_expires_in) - - def test_invalid_token_request_parameters(self): - self.prepare_data() - server = self.create_server() - url = '/oauth/token' - - # case 1 - request = self.factory.post(url) - resp = server.create_token_response(request) - data = decode_response(resp.content) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_consumer_key', data['error_description']) - - # case 2 - request = self.factory.post(url, data={'oauth_consumer_key': 'a'}) - resp = server.create_token_response(request) - data = decode_response(resp.content) - self.assertEqual(data['error'], 'invalid_client') - - # case 3 - request = self.factory.post(url, data={'oauth_consumer_key': 'client'}) - resp = server.create_token_response(request) - data = decode_response(resp.content) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_token', data['error_description']) - - # case 4 - request = self.factory.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_token': 'a' - }) - resp = server.create_token_response(request) - data = decode_response(resp.content) - self.assertEqual(data['error'], 'invalid_token') - - def test_duplicated_oauth_parameters(self): - self.prepare_data() - server = self.create_server() - url = '/oauth/token?oauth_consumer_key=client' - request = self.factory.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_token': 'abc', - 'oauth_verifier': 'abc' - }) - resp = server.create_token_response(request) - data = decode_response(resp.content) - self.assertEqual(data['error'], 'duplicated_oauth_protocol_parameter') - - @override_settings( - AUTHLIB_OAUTH1_PROVIDER={'signature_methods': ['PLAINTEXT']}) - def test_plaintext_signature(self): - self.prepare_data() - server = self.create_server() - url = '/oauth/token' - - # case 1: success - self.prepare_temporary_credential(server) - auth_header = ( - 'OAuth oauth_consumer_key="client",' - 'oauth_signature_method="PLAINTEXT",' - 'oauth_token="abc",' - 'oauth_verifier="abc-verifier",' - 'oauth_signature="secret&abc-secret"' - ) - request = self.factory.post(url, HTTP_AUTHORIZATION=auth_header) - resp = server.create_token_response(request) - data = decode_response(resp.content) - self.assertIn('oauth_token', data) - - # case 2: invalid signature - self.prepare_temporary_credential(server) - request = self.factory.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_signature_method': 'PLAINTEXT', - 'oauth_token': 'abc', - 'oauth_verifier': 'abc-verifier', - 'oauth_signature': 'invalid-signature' - }) - resp = server.create_token_response(request) - data = decode_response(resp.content) - self.assertEqual(data['error'], 'invalid_signature') - - def test_hmac_sha1_signature(self): - self.prepare_data() - server = self.create_server() - url = '/oauth/token' - - params = [ - ('oauth_consumer_key', 'client'), - ('oauth_token', 'abc'), - ('oauth_verifier', 'abc-verifier'), - ('oauth_signature_method', 'HMAC-SHA1'), - ('oauth_timestamp', str(int(time.time()))), - ('oauth_nonce', 'hmac-sha1-nonce'), - ] - base_string = signature.construct_base_string( - 'POST', 'http://testserver/oauth/token', params - ) - sig = signature.hmac_sha1_signature( - base_string, 'secret', 'abc-secret') - params.append(('oauth_signature', sig)) - auth_param = ','.join(['{}="{}"'.format(k, v) for k, v in params]) - auth_header = 'OAuth ' + auth_param - - # case 1: success - self.prepare_temporary_credential(server) - request = self.factory.post(url, HTTP_AUTHORIZATION=auth_header) - resp = server.create_token_response(request) - data = decode_response(resp.content) - self.assertIn('oauth_token', data) - - # case 2: exists nonce - self.prepare_temporary_credential(server) - request = self.factory.post(url, HTTP_AUTHORIZATION=auth_header) - resp = server.create_token_response(request) - data = decode_response(resp.content) - self.assertEqual(data['error'], 'invalid_nonce') - - @override_settings( - AUTHLIB_OAUTH1_PROVIDER={'signature_methods': ['RSA-SHA1']}) - def test_rsa_sha1_signature(self): - self.prepare_data() - server = self.create_server() - url = '/oauth/token' - - self.prepare_temporary_credential(server) - params = [ - ('oauth_consumer_key', 'client'), - ('oauth_token', 'abc'), - ('oauth_verifier', 'abc-verifier'), - ('oauth_signature_method', 'RSA-SHA1'), - ('oauth_timestamp', str(int(time.time()))), - ('oauth_nonce', 'rsa-sha1-nonce'), - ] - base_string = signature.construct_base_string( - 'POST', 'http://testserver/oauth/token', params - ) - sig = signature.rsa_sha1_signature( - base_string, read_file_path('rsa_private.pem')) - params.append(('oauth_signature', sig)) - auth_param = ','.join(['{}="{}"'.format(k, v) for k, v in params]) - auth_header = 'OAuth ' + auth_param - - request = self.factory.post(url, HTTP_AUTHORIZATION=auth_header) - resp = server.create_token_response(request) - data = decode_response(resp.content) - self.assertIn('oauth_token', data) - - # case: invalid signature - self.prepare_temporary_credential(server) - auth_param = auth_param.replace('rsa-sha1-nonce', 'alt-sha1-nonce') - auth_header = 'OAuth ' + auth_param - request = self.factory.post(url, HTTP_AUTHORIZATION=auth_header) - resp = server.create_token_response(request) - data = decode_response(resp.content) - self.assertEqual(data['error'], 'invalid_signature') +from django.test import override_settings + +from authlib.oauth1.rfc5849 import signature +from tests.util import decode_response +from tests.util import read_file_path + + +def prepare_temporary_credential(server): + token = { + "oauth_token": "abc", + "oauth_token_secret": "abc-secret", + "oauth_verifier": "abc-verifier", + "client_id": "client", + "user_id": 1, + } + key_prefix = server._temporary_credential_key_prefix + key = key_prefix + token["oauth_token"] + cache.set(key, token, timeout=server._temporary_expires_in) + + +def test_invalid_token_request_parameters(factory, server): + url = "/oauth/token" + + # case 1 + request = factory.post(url) + resp = server.create_token_response(request) + data = decode_response(resp.content) + assert data["error"] == "missing_required_parameter" + assert "oauth_consumer_key" in data["error_description"] + + # case 2 + request = factory.post(url, data={"oauth_consumer_key": "a"}) + resp = server.create_token_response(request) + data = decode_response(resp.content) + assert data["error"] == "invalid_client" + + # case 3 + request = factory.post(url, data={"oauth_consumer_key": "client"}) + resp = server.create_token_response(request) + data = decode_response(resp.content) + assert data["error"] == "missing_required_parameter" + assert "oauth_token" in data["error_description"] + + # case 4 + request = factory.post( + url, data={"oauth_consumer_key": "client", "oauth_token": "a"} + ) + resp = server.create_token_response(request) + data = decode_response(resp.content) + assert data["error"] == "invalid_token" + + +def test_duplicated_oauth_parameters(factory, server): + url = "/oauth/token?oauth_consumer_key=client" + request = factory.post( + url, + data={ + "oauth_consumer_key": "client", + "oauth_token": "abc", + "oauth_verifier": "abc", + }, + ) + resp = server.create_token_response(request) + data = decode_response(resp.content) + assert data["error"] == "duplicated_oauth_protocol_parameter" + + +@override_settings(AUTHLIB_OAUTH1_PROVIDER={"signature_methods": ["PLAINTEXT"]}) +def test_plaintext_signature(factory, server): + url = "/oauth/token" + + # case 1: success + prepare_temporary_credential(server) + auth_header = ( + 'OAuth oauth_consumer_key="client",' + 'oauth_signature_method="PLAINTEXT",' + 'oauth_token="abc",' + 'oauth_verifier="abc-verifier",' + 'oauth_signature="secret&abc-secret"' + ) + request = factory.post(url, HTTP_AUTHORIZATION=auth_header) + resp = server.create_token_response(request) + data = decode_response(resp.content) + assert "oauth_token" in data + + # case 2: invalid signature + prepare_temporary_credential(server) + request = factory.post( + url, + data={ + "oauth_consumer_key": "client", + "oauth_signature_method": "PLAINTEXT", + "oauth_token": "abc", + "oauth_verifier": "abc-verifier", + "oauth_signature": "invalid-signature", + }, + ) + resp = server.create_token_response(request) + data = decode_response(resp.content) + assert data["error"] == "invalid_signature" + + +def test_hmac_sha1_signature(factory, server): + url = "/oauth/token" + + params = [ + ("oauth_consumer_key", "client"), + ("oauth_token", "abc"), + ("oauth_verifier", "abc-verifier"), + ("oauth_signature_method", "HMAC-SHA1"), + ("oauth_timestamp", str(int(time.time()))), + ("oauth_nonce", "hmac-sha1-nonce"), + ] + base_string = signature.construct_base_string( + "POST", "http://testserver/oauth/token", params + ) + sig = signature.hmac_sha1_signature(base_string, "secret", "abc-secret") + params.append(("oauth_signature", sig)) + auth_param = ",".join([f'{k}="{v}"' for k, v in params]) + auth_header = "OAuth " + auth_param + + # case 1: success + prepare_temporary_credential(server) + request = factory.post(url, HTTP_AUTHORIZATION=auth_header) + resp = server.create_token_response(request) + data = decode_response(resp.content) + assert "oauth_token" in data + + # case 2: exists nonce + prepare_temporary_credential(server) + request = factory.post(url, HTTP_AUTHORIZATION=auth_header) + resp = server.create_token_response(request) + data = decode_response(resp.content) + assert data["error"] == "invalid_nonce" + + +def test_rsa_sha1_signature(factory, rsa_server): + url = "/oauth/token" + server = rsa_server + + prepare_temporary_credential(server) + params = [ + ("oauth_consumer_key", "client"), + ("oauth_token", "abc"), + ("oauth_verifier", "abc-verifier"), + ("oauth_signature_method", "RSA-SHA1"), + ("oauth_timestamp", str(int(time.time()))), + ("oauth_nonce", "rsa-sha1-nonce"), + ] + base_string = signature.construct_base_string( + "POST", "http://testserver/oauth/token", params + ) + sig = signature.rsa_sha1_signature(base_string, read_file_path("rsa_private.pem")) + params.append(("oauth_signature", sig)) + auth_param = ",".join([f'{k}="{v}"' for k, v in params]) + auth_header = "OAuth " + auth_param + + request = factory.post(url, HTTP_AUTHORIZATION=auth_header) + resp = server.create_token_response(request) + data = decode_response(resp.content) + assert "oauth_token" in data + + # case: invalid signature + prepare_temporary_credential(server) + auth_param = auth_param.replace("rsa-sha1-nonce", "alt-sha1-nonce") + auth_header = "OAuth " + auth_param + request = factory.post(url, HTTP_AUTHORIZATION=auth_header) + resp = server.create_token_response(request) + data = decode_response(resp.content) + assert data["error"] == "invalid_signature" diff --git a/tests/django/test_oauth2/conftest.py b/tests/django/test_oauth2/conftest.py new file mode 100644 index 000000000..d933d7615 --- /dev/null +++ b/tests/django/test_oauth2/conftest.py @@ -0,0 +1,49 @@ +import os + +import pytest + +from authlib.integrations.django_oauth2 import AuthorizationServer + +from .models import Client +from .models import OAuth2Token +from .models import User + +pytestmark = pytest.mark.django_db + + +@pytest.fixture(autouse=True) +def env(): + os.environ["AUTHLIB_INSECURE_TRANSPORT"] = "true" + yield + os.environ.pop("AUTHLIB_INSECURE_TRANSPORT", None) + + +@pytest.fixture(autouse=True) +def server(settings): + settings.AUTHLIB_OAUTH2_PROVIDER = {} + return AuthorizationServer(Client, OAuth2Token) + + +@pytest.fixture(autouse=True) +def user(db): + user = User(username="foo") + user.set_password("ok") + user.save() + yield user + user.delete() + + +@pytest.fixture +def token(user): + token = OAuth2Token( + user_id=user.pk, + client_id="client-id", + token_type="bearer", + access_token="a1", + refresh_token="r1", + scope="profile", + expires_in=3600, + ) + token.save() + yield token + token.delete() diff --git a/tests/django/test_oauth2/models.py b/tests/django/test_oauth2/models.py index 00106fd00..4c1533b0b 100644 --- a/tests/django/test_oauth2/models.py +++ b/tests/django/test_oauth2/models.py @@ -1,20 +1,19 @@ import time -from django.db.models import ( - Model, - CharField, - TextField, - BooleanField, - IntegerField, -) -from django.db.models import ForeignKey, CASCADE + from django.contrib.auth.models import User +from django.db.models import CASCADE +from django.db.models import CharField +from django.db.models import ForeignKey +from django.db.models import IntegerField +from django.db.models import Model +from django.db.models import TextField + from authlib.common.security import generate_token -from authlib.oauth2.rfc6749 import ( - ClientMixin, - TokenMixin, - AuthorizationCodeMixin, -) -from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope +from authlib.oauth2.rfc6749 import AuthorizationCodeMixin +from authlib.oauth2.rfc6749 import ClientMixin +from authlib.oauth2.rfc6749 import TokenMixin +from authlib.oauth2.rfc6749.util import list_to_scope +from authlib.oauth2.rfc6749.util import scope_to_list def now_timestamp(): @@ -25,12 +24,12 @@ class Client(Model, ClientMixin): user = ForeignKey(User, on_delete=CASCADE) client_id = CharField(max_length=48, unique=True, db_index=True) client_secret = CharField(max_length=48, blank=True) - redirect_uris = TextField(default='') - default_redirect_uri = TextField(blank=False, default='') - scope = TextField(default='') - response_type = TextField(default='') - grant_type = TextField(default='') - token_endpoint_auth_method = CharField(max_length=120, default='') + redirect_uris = TextField(default="") + default_redirect_uri = TextField(blank=False, default="") + scope = TextField(default="") + response_type = TextField(default="") + grant_type = TextField(default="") + token_endpoint_auth_method = CharField(max_length=120, default="") def get_client_id(self): return self.client_id @@ -40,7 +39,7 @@ def get_default_redirect_uri(self): def get_allowed_scope(self, scope): if not scope: - return '' + return "" allowed = set(scope_to_list(self.scope)) return list_to_scope([s for s in scope.split() if s in allowed]) @@ -49,14 +48,13 @@ def check_redirect_uri(self, redirect_uri): return True return redirect_uri in self.redirect_uris - def has_client_secret(self): - return bool(self.client_secret) - def check_client_secret(self, client_secret): return self.client_secret == client_secret - def check_token_endpoint_auth_method(self, method): - return self.token_endpoint_auth_method == method + def check_endpoint_auth_method(self, method, endpoint): + if endpoint == "token": + return self.token_endpoint_auth_method == method + return True def check_response_type(self, response_type): allowed = self.response_type.split() @@ -73,13 +71,15 @@ class OAuth2Token(Model, TokenMixin): token_type = CharField(max_length=40) access_token = CharField(max_length=255, unique=True, null=False) refresh_token = CharField(max_length=255, db_index=True) - scope = TextField(default='') - revoked = BooleanField(default=False) + scope = TextField(default="") + issued_at = IntegerField(null=False, default=now_timestamp) expires_in = IntegerField(null=False, default=0) + access_token_revoked_at = IntegerField(default=0) + refresh_token_revoked_at = IntegerField(default=0) - def get_client_id(self): - return self.client_id + def check_client(self, client): + return self.client_id == client.client_id def get_scope(self): return self.scope @@ -87,24 +87,27 @@ def get_scope(self): def get_expires_in(self): return self.expires_in - def get_expires_at(self): - return self.issued_at + self.expires_in + def is_revoked(self): + return self.access_token_revoked_at or self.refresh_token_revoked_at - def is_refresh_token_active(self): - if self.revoked: + def is_expired(self): + if not self.expires_in: return False - expired_at = self.issued_at + self.expires_in * 2 - return expired_at >= time.time() + expires_at = self.issued_at + self.expires_in + return expires_at < time.time() + + def is_refresh_token_active(self): + return not self.refresh_token_revoked_at class OAuth2Code(Model, AuthorizationCodeMixin): user = ForeignKey(User, on_delete=CASCADE) client_id = CharField(max_length=48, db_index=True) code = CharField(max_length=120, unique=True, null=False) - redirect_uri = TextField(default='', null=True) - response_type = TextField(default='') - scope = TextField(default='', null=True) + redirect_uri = TextField(default="", null=True) + response_type = TextField(default="") + scope = TextField(default="", null=True) auth_time = IntegerField(null=False, default=now_timestamp) def is_expired(self): @@ -114,13 +117,13 @@ def get_redirect_uri(self): return self.redirect_uri def get_scope(self): - return self.scope or '' + return self.scope or "" def get_auth_time(self): return self.auth_time -class CodeGrantMixin(object): +class CodeGrantMixin: def query_authorization_code(self, code, client): try: item = OAuth2Code.objects.get(code=code, client_id=client.client_id) @@ -142,11 +145,11 @@ def generate_authorization_code(client, grant_user, request, **extra): item = OAuth2Code( code=code, client_id=client.client_id, - redirect_uri=request.redirect_uri, - response_type=request.response_type, - scope=request.scope, + redirect_uri=request.payload.redirect_uri, + response_type=request.payload.response_type, + scope=request.payload.scope, user=grant_user, - **extra + **extra, ) item.save() return code diff --git a/tests/django/test_oauth2/oauth2_server.py b/tests/django/test_oauth2/oauth2_server.py index 6dee878e8..8704ed3f1 100644 --- a/tests/django/test_oauth2/oauth2_server.py +++ b/tests/django/test_oauth2/oauth2_server.py @@ -1,19 +1,10 @@ -import os import base64 -from authlib.common.encoding import to_bytes, to_unicode -from authlib.integrations.django_oauth2 import AuthorizationServer -from .models import Client, OAuth2Token -from ..base import TestCase as _TestCase +from authlib.common.encoding import to_bytes +from authlib.common.encoding import to_unicode -os.environ['AUTHLIB_INSECURE_TRANSPORT'] = 'true' - -class TestCase(_TestCase): - def create_server(self): - return AuthorizationServer(Client, OAuth2Token) - - def create_basic_auth(self, username, password): - text = '{}:{}'.format(username, password) - auth = to_unicode(base64.b64encode(to_bytes(text))) - return 'Basic ' + auth +def create_basic_auth(username, password): + text = f"{username}:{password}" + auth = to_unicode(base64.b64encode(to_bytes(text))) + return "Basic " + auth diff --git a/tests/django/test_oauth2/test_authorization_code_grant.py b/tests/django/test_oauth2/test_authorization_code_grant.py index 8d4e580e7..9b39643a5 100644 --- a/tests/django/test_oauth2/test_authorization_code_grant.py +++ b/tests/django/test_oauth2/test_authorization_code_grant.py @@ -1,187 +1,205 @@ import json -from authlib.oauth2.rfc6749 import grants, errors -from authlib.common.urls import urlparse, url_decode -from django.test import override_settings -from .models import User, Client, OAuth2Code +import os + +import pytest + +from authlib.common.urls import url_decode +from authlib.common.urls import urlparse +from authlib.oauth2.rfc6749 import errors +from authlib.oauth2.rfc6749 import grants + +from .models import Client from .models import CodeGrantMixin -from .oauth2_server import TestCase - - -class AuthorizationCodeGrant(CodeGrantMixin, grants.AuthorizationCodeGrant): - TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none'] - - def save_authorization_code(self, code, request): - auth_code = OAuth2Code( - code=code, - client_id=request.client.client_id, - redirect_uri=request.redirect_uri, - response_type=request.response_type, - scope=request.scope, - user=request.user, - ) - auth_code.save() - - -class AuthorizationCodeTest(TestCase): - def create_server(self): - server = super(AuthorizationCodeTest, self).create_server() - server.register_grant(AuthorizationCodeGrant) - return server - - def prepare_data(self, response_type='code', grant_type='authorization_code', scope=''): - user = User(username='foo') - user.save() - client = Client( - user_id=user.pk, - client_id='client', - client_secret='secret', - response_type=response_type, - grant_type=grant_type, - scope=scope, - token_endpoint_auth_method='client_secret_basic', - default_redirect_uri='https://a.b', - ) - client.save() - - def test_validate_consent_request_client(self): - server = self.create_server() - url = '/authorize?response_type=code' - request = self.factory.get(url) - self.assertRaises( - errors.InvalidClientError, - server.validate_consent_request, - request - ) - - url = '/authorize?response_type=code&client_id=client' - request = self.factory.get(url) - self.assertRaises( - errors.InvalidClientError, - server.validate_consent_request, - request - ) - - self.prepare_data(response_type='') - self.assertRaises( - errors.UnauthorizedClientError, - server.validate_consent_request, - request - ) - - def test_validate_consent_request_redirect_uri(self): - server = self.create_server() - self.prepare_data() - - base_url = '/authorize?response_type=code&client_id=client' - url = base_url + '&redirect_uri=https%3A%2F%2Fa.c' - request = self.factory.get(url) - self.assertRaises( - errors.InvalidRequestError, - server.validate_consent_request, - request - ) - - url = base_url + '&redirect_uri=https%3A%2F%2Fa.b' - request = self.factory.get(url) - grant = server.validate_consent_request(request) - self.assertIsInstance(grant, AuthorizationCodeGrant) - - def test_validate_consent_request_scope(self): - server = self.create_server() - server.metadata = {'scopes_supported': ['profile']} - - self.prepare_data() - base_url = '/authorize?response_type=code&client_id=client' - url = base_url + '&scope=invalid' - request = self.factory.get(url) - self.assertRaises( - errors.InvalidScopeError, - server.validate_consent_request, - request - ) - - def test_create_authorization_response(self): - server = self.create_server() - self.prepare_data() - data = {'response_type': 'code', 'client_id': 'client'} - request = self.factory.post('/authorize', data=data) - server.validate_consent_request(request) - - resp = server.create_authorization_response(request) - self.assertEqual(resp.status_code, 302) - self.assertIn('error=access_denied', resp['Location']) - - grant_user = User.objects.get(username='foo') - resp = server.create_authorization_response(request, grant_user=grant_user) - self.assertEqual(resp.status_code, 302) - self.assertIn('code=', resp['Location']) - - def test_create_token_response_invalid(self): - server = self.create_server() - self.prepare_data() - - # case: no auth - request = self.factory.post('/oauth/token', data={'grant_type': 'authorization_code'}) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 401) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_client') - - auth_header = self.create_basic_auth('client', 'secret') - - # case: no code - request = self.factory.post( - '/oauth/token', - data={'grant_type': 'authorization_code'}, - HTTP_AUTHORIZATION=auth_header, - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 400) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_request') - - # case: invalid code - request = self.factory.post( - '/oauth/token', - data={'grant_type': 'authorization_code', 'code': 'invalid'}, - HTTP_AUTHORIZATION=auth_header, - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 400) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_request') - - def test_create_token_response_success(self): - self.prepare_data() - data = self.get_token_response() - self.assertIn('access_token', data) - self.assertNotIn('refresh_token', data) - - @override_settings( - AUTHLIB_OAUTH2_PROVIDER={'refresh_token_generator': True}) - def test_create_token_response_with_refresh_token(self): - self.prepare_data(grant_type='authorization_code\nrefresh_token') - data = self.get_token_response() - self.assertIn('access_token', data) - self.assertIn('refresh_token', data) - - def get_token_response(self): - server = self.create_server() - data = {'response_type': 'code', 'client_id': 'client'} - request = self.factory.post('/authorize', data=data) - grant_user = User.objects.get(username='foo') - resp = server.create_authorization_response(request, grant_user=grant_user) - self.assertEqual(resp.status_code, 302) - - params = dict(url_decode(urlparse.urlparse(resp['Location']).query)) - code = params['code'] - - request = self.factory.post( - '/oauth/token', - data={'grant_type': 'authorization_code', 'code': code}, - HTTP_AUTHORIZATION=self.create_basic_auth('client', 'secret'), - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 200) - data = json.loads(resp.content) - return data +from .models import OAuth2Code +from .models import User +from .oauth2_server import create_basic_auth + + +@pytest.fixture(autouse=True) +def server(server): + class AuthorizationCodeGrant(CodeGrantMixin, grants.AuthorizationCodeGrant): + TOKEN_ENDPOINT_AUTH_METHODS = [ + "client_secret_basic", + "client_secret_post", + "none", + ] + + def save_authorization_code(self, code, request): + auth_code = OAuth2Code( + code=code, + client_id=request.client.client_id, + redirect_uri=request.payload.redirect_uri, + response_type=request.payload.response_type, + scope=request.payload.scope, + user=request.user, + ) + auth_code.save() + + server.register_grant(AuthorizationCodeGrant) + return server + + +@pytest.fixture(autouse=True) +def client(user): + client = Client( + user_id=user.pk, + client_id="client-id", + client_secret="client-secret", + response_type="code", + grant_type="authorization_code", + scope="", + token_endpoint_auth_method="client_secret_basic", + default_redirect_uri="https://client.test", + ) + client.save() + yield client + client.delete() + + +def test_get_consent_grant_client(factory, server, client): + url = "/authorize?response_type=code" + request = factory.get(url) + with pytest.raises(errors.InvalidClientError): + server.get_consent_grant(request) + + url = "/authorize?response_type=code&client_id=invalid-id" + request = factory.get(url) + with pytest.raises(errors.InvalidClientError): + server.get_consent_grant(request) + + client.response_type = "" + client.save() + url = "/authorize?response_type=code&client_id=client-id" + request = factory.get(url) + with pytest.raises(errors.UnauthorizedClientError): + server.get_consent_grant(request) + + url = "/authorize?response_type=code&client_id=client-id&scope=profile&state=bar&redirect_uri=https%3A%2F%2Fclient.test&response_type=code" + request = factory.get(url) + with pytest.raises(errors.InvalidRequestError): + server.get_consent_grant(request) + + +def test_get_consent_grant_redirect_uri(factory, server): + base_url = "/authorize?response_type=code&client_id=client-id" + url = base_url + "&redirect_uri=https%3A%2F%2Fa.c" + request = factory.get(url) + with pytest.raises(errors.InvalidRequestError): + server.get_consent_grant(request) + + url = base_url + "&redirect_uri=https%3A%2F%2Fclient.test" + request = factory.get(url) + grant = server.get_consent_grant(request) + assert isinstance(grant, grants.AuthorizationCodeGrant) + + +def test_get_consent_grant_scope(factory, server): + server.scopes_supported = ["profile"] + base_url = "/authorize?response_type=code&client_id=client-id" + url = base_url + "&scope=invalid" + request = factory.get(url) + with pytest.raises(errors.InvalidScopeError): + server.get_consent_grant(request) + + +def test_create_authorization_response(factory, server): + data = {"response_type": "code", "client_id": "client-id"} + request = factory.post("/authorize", data=data) + grant = server.get_consent_grant(request) + + resp = server.create_authorization_response(request, grant=grant) + assert resp.status_code == 302 + assert "error=access_denied" in resp["Location"] + + grant_user = User.objects.get(username="foo") + resp = server.create_authorization_response( + request, grant=grant, grant_user=grant_user + ) + assert resp.status_code == 302 + assert "code=" in resp["Location"] + + +def test_create_token_response_invalid(factory, server): + # case: no auth + request = factory.post("/oauth/token", data={"grant_type": "authorization_code"}) + resp = server.create_token_response(request) + assert resp.status_code == 401 + data = json.loads(resp.content) + assert data["error"] == "invalid_client" + + auth_header = create_basic_auth("client-id", "client-secret") + + # case: no code + request = factory.post( + "/oauth/token", + data={"grant_type": "authorization_code"}, + HTTP_AUTHORIZATION=auth_header, + ) + resp = server.create_token_response(request) + assert resp.status_code == 400 + data = json.loads(resp.content) + assert data["error"] == "invalid_request" + + # case: invalid code + request = factory.post( + "/oauth/token", + data={"grant_type": "authorization_code", "code": "invalid"}, + HTTP_AUTHORIZATION=auth_header, + ) + resp = server.create_token_response(request) + assert resp.status_code == 400 + data = json.loads(resp.content) + assert data["error"] == "invalid_grant" + + +def test_create_token_response_success(factory, server): + data = get_token_response(factory, server) + assert "access_token" in data + assert "refresh_token" not in data + + +def test_create_token_response_with_refresh_token(factory, server, client, settings): + settings.AUTHLIB_OAUTH2_PROVIDER["refresh_token_generator"] = True + server.load_config(settings.AUTHLIB_OAUTH2_PROVIDER) + client.grant_type = "authorization_code\nrefresh_token" + client.save() + data = get_token_response(factory, server) + assert "access_token" in data + assert "refresh_token" in data + + +def test_insecure_transport_error_with_payload_access(factory, server): + """Test that InsecureTransportError is raised properly without AttributeError + when accessing request.payload on non-HTTPS requests (issue #795).""" + del os.environ["AUTHLIB_INSECURE_TRANSPORT"] + + request = factory.get( + "https://provider.test/authorize?response_type=code&client_id=client-id" + ) + + with pytest.raises(errors.InsecureTransportError): + server.get_consent_grant(request) + + +def get_token_response(factory, server): + data = {"response_type": "code", "client_id": "client-id"} + request = factory.post("/authorize", data=data) + grant_user = User.objects.get(username="foo") + grant = server.get_consent_grant(request) + resp = server.create_authorization_response( + request, grant=grant, grant_user=grant_user + ) + assert resp.status_code == 302 + + params = dict(url_decode(urlparse.urlparse(resp["Location"]).query)) + code = params["code"] + + request = factory.post( + "/oauth/token", + data={"grant_type": "authorization_code", "code": code}, + HTTP_AUTHORIZATION=create_basic_auth("client-id", "client-secret"), + ) + resp = server.create_token_response(request) + assert resp.status_code == 200 + data = json.loads(resp.content) + return data diff --git a/tests/django/test_oauth2/test_client_credentials_grant.py b/tests/django/test_oauth2/test_client_credentials_grant.py index b54e0babc..db7280672 100644 --- a/tests/django/test_oauth2/test_client_credentials_grant.py +++ b/tests/django/test_oauth2/test_client_credentials_grant.py @@ -1,100 +1,101 @@ import json + +import pytest + from authlib.oauth2.rfc6749 import grants -from .oauth2_server import TestCase -from .models import User, Client - - -class PasswordTest(TestCase): - def create_server(self): - server = super(PasswordTest, self).create_server() - server.register_grant(grants.ClientCredentialsGrant) - return server - - def prepare_data(self, grant_type='client_credentials', scope=''): - user = User(username='foo') - user.save() - client = Client( - user_id=user.pk, - client_id='client', - client_secret='secret', - scope=scope, - grant_type=grant_type, - token_endpoint_auth_method='client_secret_basic', - default_redirect_uri='https://a.b', - ) - client.save() - - def test_invalid_client(self): - server = self.create_server() - self.prepare_data() - request = self.factory.post( - '/oauth/token', - data={'grant_type': 'client_credentials'}, - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 401) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_client') - - request = self.factory.post( - '/oauth/token', - data={'grant_type': 'client_credentials'}, - HTTP_AUTHORIZATION=self.create_basic_auth('invalid', 'secret'), - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 401) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_client') - - def test_invalid_scope(self): - server = self.create_server() - server.metadata = {'scopes_supported': ['profile']} - self.prepare_data() - request = self.factory.post( - '/oauth/token', - data={'grant_type': 'client_credentials', 'scope': 'invalid'}, - HTTP_AUTHORIZATION=self.create_basic_auth('client', 'secret'), - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 400) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_scope') - - def test_invalid_request(self): - server = self.create_server() - self.prepare_data() - - request = self.factory.get( - '/oauth/token?grant_type=client_credentials', - HTTP_AUTHORIZATION=self.create_basic_auth('client', 'secret'), - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 400) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'unsupported_grant_type') - - def test_unauthorized_client(self): - server = self.create_server() - self.prepare_data(grant_type='invalid') - request = self.factory.post( - '/oauth/token', - data={'grant_type': 'client_credentials'}, - HTTP_AUTHORIZATION=self.create_basic_auth('client', 'secret'), - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 400) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'unauthorized_client') - - def test_authorize_token(self): - server = self.create_server() - self.prepare_data() - request = self.factory.post( - '/oauth/token', - data={'grant_type': 'client_credentials'}, - HTTP_AUTHORIZATION=self.create_basic_auth('client', 'secret'), - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 200) - data = json.loads(resp.content) - self.assertIn('access_token', data) + +from .models import Client +from .oauth2_server import create_basic_auth + + +@pytest.fixture(autouse=True) +def server(server): + server.register_grant(grants.ClientCredentialsGrant) + return server + + +@pytest.fixture(autouse=True) +def client(user): + client = Client( + user_id=user.pk, + client_id="client-id", + client_secret="client-secret", + scope="", + grant_type="client_credentials", + token_endpoint_auth_method="client_secret_basic", + default_redirect_uri="https://client.test", + ) + client.save() + yield client + client.delete() + + +def test_invalid_client(factory, server): + request = factory.post( + "/oauth/token", + data={"grant_type": "client_credentials"}, + ) + resp = server.create_token_response(request) + assert resp.status_code == 401 + data = json.loads(resp.content) + assert data["error"] == "invalid_client" + + request = factory.post( + "/oauth/token", + data={"grant_type": "client_credentials"}, + HTTP_AUTHORIZATION=create_basic_auth("invalid", "client-secret"), + ) + resp = server.create_token_response(request) + assert resp.status_code == 401 + data = json.loads(resp.content) + assert data["error"] == "invalid_client" + + +def test_invalid_scope(factory, server): + server.scopes_supported = ["profile"] + request = factory.post( + "/oauth/token", + data={"grant_type": "client_credentials", "scope": "invalid"}, + HTTP_AUTHORIZATION=create_basic_auth("client-id", "client-secret"), + ) + resp = server.create_token_response(request) + assert resp.status_code == 400 + data = json.loads(resp.content) + assert data["error"] == "invalid_scope" + + +def test_invalid_request(factory, server): + request = factory.get( + "/oauth/token?grant_type=client_credentials", + HTTP_AUTHORIZATION=create_basic_auth("client-id", "client-secret"), + ) + resp = server.create_token_response(request) + assert resp.status_code == 400 + data = json.loads(resp.content) + assert data["error"] == "unsupported_grant_type" + + +def test_unauthorized_client(factory, server, client): + client.grant_type = "invalid" + client.save() + request = factory.post( + "/oauth/token", + data={"grant_type": "client_credentials"}, + HTTP_AUTHORIZATION=create_basic_auth("client-id", "client-secret"), + ) + resp = server.create_token_response(request) + assert resp.status_code == 400 + data = json.loads(resp.content) + assert data["error"] == "unauthorized_client" + + +def test_authorize_token(factory, server): + request = factory.post( + "/oauth/token", + data={"grant_type": "client_credentials"}, + HTTP_AUTHORIZATION=create_basic_auth("client-id", "client-secret"), + ) + resp = server.create_token_response(request) + assert resp.status_code == 200 + data = json.loads(resp.content) + assert "access_token" in data diff --git a/tests/django/test_oauth2/test_implicit_grant.py b/tests/django/test_oauth2/test_implicit_grant.py index ef4a16f42..8ac935ee5 100644 --- a/tests/django/test_oauth2/test_implicit_grant.py +++ b/tests/django/test_oauth2/test_implicit_grant.py @@ -1,81 +1,78 @@ -from authlib.oauth2.rfc6749 import grants, errors -from authlib.common.urls import urlparse, url_decode -from .oauth2_server import TestCase -from .models import User, Client - - -class ImplicitTest(TestCase): - def create_server(self): - server = super(ImplicitTest, self).create_server() - server.register_grant(grants.ImplicitGrant) - return server - - def prepare_data(self, response_type='token', scope=''): - user = User(username='foo') - user.save() - client = Client( - user_id=user.pk, - client_id='client', - response_type=response_type, - scope=scope, - token_endpoint_auth_method='none', - default_redirect_uri='https://a.b', - ) - client.save() - - def test_validate_consent_request_client(self): - server = self.create_server() - url = '/authorize?response_type=token' - request = self.factory.get(url) - self.assertRaises( - errors.InvalidClientError, - server.validate_consent_request, - request - ) - - url = '/authorize?response_type=token&client_id=client' - request = self.factory.get(url) - self.assertRaises( - errors.InvalidClientError, - server.validate_consent_request, - request - ) - - self.prepare_data(response_type='') - self.assertRaises( - errors.UnauthorizedClientError, - server.validate_consent_request, - request - ) - - def test_validate_consent_request_scope(self): - server = self.create_server() - server.metadata = {'scopes_supported': ['profile']} - - self.prepare_data() - base_url = '/authorize?response_type=token&client_id=client' - url = base_url + '&scope=invalid' - request = self.factory.get(url) - self.assertRaises( - errors.InvalidScopeError, - server.validate_consent_request, - request - ) - - def test_create_authorization_response(self): - server = self.create_server() - self.prepare_data() - data = {'response_type': 'token', 'client_id': 'client'} - request = self.factory.post('/authorize', data=data) - server.validate_consent_request(request) - - resp = server.create_authorization_response(request) - self.assertEqual(resp.status_code, 302) - params = dict(url_decode(urlparse.urlparse(resp['Location']).fragment)) - self.assertEqual(params['error'], 'access_denied') - - grant_user = User.objects.get(username='foo') - resp = server.create_authorization_response(request, grant_user=grant_user) - self.assertEqual(resp.status_code, 302) - params = dict(url_decode(urlparse.urlparse(resp['Location']).fragment)) - self.assertIn('access_token', params) +import pytest + +from authlib.common.urls import url_decode +from authlib.common.urls import urlparse +from authlib.oauth2.rfc6749 import errors +from authlib.oauth2.rfc6749 import grants + +from .models import Client +from .models import User + + +@pytest.fixture(autouse=True) +def server(server): + server.register_grant(grants.ImplicitGrant) + return server + + +@pytest.fixture(autouse=True) +def client(user): + client = Client( + user_id=user.pk, + client_id="client-id", + response_type="token", + scope="", + token_endpoint_auth_method="none", + default_redirect_uri="https://client.test", + ) + client.save() + yield client + client.delete() + + +def test_get_consent_grant_client(factory, server, client): + url = "/authorize?response_type=token" + request = factory.get(url) + with pytest.raises(errors.InvalidClientError): + server.get_consent_grant(request) + + url = "/authorize?response_type=token&client_id=invalid-id" + request = factory.get(url) + with pytest.raises(errors.InvalidClientError): + server.get_consent_grant(request) + + client.response_type = "" + client.save() + url = "/authorize?response_type=token&client_id=client-id" + request = factory.get(url) + with pytest.raises(errors.UnauthorizedClientError): + server.get_consent_grant(request) + + +def test_get_consent_grant_scope(factory, server): + server.scopes_supported = ["profile"] + + base_url = "/authorize?response_type=token&client_id=client-id" + url = base_url + "&scope=invalid" + request = factory.get(url) + with pytest.raises(errors.InvalidScopeError): + server.get_consent_grant(request) + + +def test_create_authorization_response(factory, server): + data = {"response_type": "token", "client_id": "client-id"} + request = factory.post("/authorize", data=data) + grant = server.get_consent_grant(request) + + resp = server.create_authorization_response(request, grant=grant) + assert resp.status_code == 302 + params = dict(url_decode(urlparse.urlparse(resp["Location"]).fragment)) + assert params["error"] == "access_denied" + + grant_user = User.objects.get(username="foo") + resp = server.create_authorization_response( + request, grant=grant, grant_user=grant_user + ) + assert resp.status_code == 302 + params = dict(url_decode(urlparse.urlparse(resp["Location"]).fragment)) + assert "access_token" in params diff --git a/tests/django/test_oauth2/test_password_grant.py b/tests/django/test_oauth2/test_password_grant.py index 4bb2f71f3..42df2f835 100644 --- a/tests/django/test_oauth2/test_password_grant.py +++ b/tests/django/test_oauth2/test_password_grant.py @@ -1,164 +1,166 @@ import json + +import pytest + from authlib.oauth2.rfc6749.grants import ( ResourceOwnerPasswordCredentialsGrant as _PasswordGrant, ) -from .oauth2_server import TestCase -from .models import User, Client - - -class PasswordGrant(_PasswordGrant): - def authenticate_user(self, username, password): - try: - user = User.objects.get(username=username) - if user.check_password(password): - return user - except User.DoesNotExist: - return None - - -class PasswordTest(TestCase): - def create_server(self): - server = super(PasswordTest, self).create_server() - server.register_grant(PasswordGrant) - return server - - def prepare_data(self, grant_type='password', scope=''): - user = User(username='foo') - user.set_password('ok') - user.save() - client = Client( - user_id=user.pk, - client_id='client', - client_secret='secret', - scope=scope, - grant_type=grant_type, - token_endpoint_auth_method='client_secret_basic', - default_redirect_uri='https://a.b', - ) - client.save() - - def test_invalid_client(self): - server = self.create_server() - self.prepare_data() - request = self.factory.post( - '/oauth/token', - data={'grant_type': 'password', 'username': 'foo', 'password': 'ok'}, - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 401) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_client') - - request = self.factory.post( - '/oauth/token', - data={'grant_type': 'password', 'username': 'foo', 'password': 'ok'}, - HTTP_AUTHORIZATION=self.create_basic_auth('invalid', 'secret'), - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 401) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_client') - - def test_invalid_scope(self): - server = self.create_server() - server.metadata = {'scopes_supported': ['profile']} - self.prepare_data() - request = self.factory.post( - '/oauth/token', - data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'ok', - 'scope': 'invalid', - }, - HTTP_AUTHORIZATION=self.create_basic_auth('client', 'secret'), - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 400) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_scope') - - def test_invalid_request(self): - server = self.create_server() - self.prepare_data() - auth_header = self.create_basic_auth('client', 'secret') - - # case 1 - request = self.factory.get( - '/oauth/token?grant_type=password', - HTTP_AUTHORIZATION=auth_header, - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 400) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'unsupported_grant_type') - - # case 2 - request = self.factory.post( - '/oauth/token', data={'grant_type': 'password'}, - HTTP_AUTHORIZATION=auth_header, - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 400) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_request') - - # case 3 - request = self.factory.post( - '/oauth/token', data={'grant_type': 'password', 'username': 'foo'}, - HTTP_AUTHORIZATION=auth_header, - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 400) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_request') - - # case 4 - request = self.factory.post( - '/oauth/token', - data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'wrong', - }, - HTTP_AUTHORIZATION=auth_header, - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 400) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_request') - - def test_unauthorized_client(self): - server = self.create_server() - self.prepare_data(grant_type='invalid') - request = self.factory.post( - '/oauth/token', - data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'ok', - }, - HTTP_AUTHORIZATION=self.create_basic_auth('client', 'secret'), - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 400) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'unauthorized_client') - - def test_authorize_token(self): - server = self.create_server() - self.prepare_data() - request = self.factory.post( - '/oauth/token', - data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'ok', - }, - HTTP_AUTHORIZATION=self.create_basic_auth('client', 'secret'), - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 200) - data = json.loads(resp.content) - self.assertIn('access_token', data) +from .models import Client +from .models import User +from .oauth2_server import create_basic_auth + + +@pytest.fixture(autouse=True) +def server(server): + class PasswordGrant(_PasswordGrant): + def authenticate_user(self, username, password): + try: + user = User.objects.get(username=username) + if user.check_password(password): + return user + except User.DoesNotExist: + return None + + server.register_grant(PasswordGrant) + return server + + +@pytest.fixture(autouse=True) +def client(user): + client = Client( + user_id=user.pk, + client_id="client-id", + client_secret="client-secret", + scope="", + grant_type="password", + token_endpoint_auth_method="client_secret_basic", + default_redirect_uri="https://client.test", + ) + client.save() + yield client + client.delete() + + +def test_invalid_client(factory, server): + request = factory.post( + "/oauth/token", + data={"grant_type": "password", "username": "foo", "password": "ok"}, + ) + resp = server.create_token_response(request) + assert resp.status_code == 401 + data = json.loads(resp.content) + assert data["error"] == "invalid_client" + + request = factory.post( + "/oauth/token", + data={"grant_type": "password", "username": "foo", "password": "ok"}, + HTTP_AUTHORIZATION=create_basic_auth("invalid", "client-secret"), + ) + resp = server.create_token_response(request) + assert resp.status_code == 401 + data = json.loads(resp.content) + assert data["error"] == "invalid_client" + + +def test_invalid_scope(factory, server): + server.scopes_supported = ["profile"] + request = factory.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + "scope": "invalid", + }, + HTTP_AUTHORIZATION=create_basic_auth("client-id", "client-secret"), + ) + resp = server.create_token_response(request) + assert resp.status_code == 400 + data = json.loads(resp.content) + assert data["error"] == "invalid_scope" + + +def test_invalid_request(factory, server): + auth_header = create_basic_auth("client-id", "client-secret") + + # case 1 + request = factory.get( + "/oauth/token?grant_type=password", + HTTP_AUTHORIZATION=auth_header, + ) + resp = server.create_token_response(request) + assert resp.status_code == 400 + data = json.loads(resp.content) + assert data["error"] == "unsupported_grant_type" + + # case 2 + request = factory.post( + "/oauth/token", + data={"grant_type": "password"}, + HTTP_AUTHORIZATION=auth_header, + ) + resp = server.create_token_response(request) + assert resp.status_code == 400 + data = json.loads(resp.content) + assert data["error"] == "invalid_request" + + # case 3 + request = factory.post( + "/oauth/token", + data={"grant_type": "password", "username": "foo"}, + HTTP_AUTHORIZATION=auth_header, + ) + resp = server.create_token_response(request) + assert resp.status_code == 400 + data = json.loads(resp.content) + assert data["error"] == "invalid_request" + + # case 4 + request = factory.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "wrong", + }, + HTTP_AUTHORIZATION=auth_header, + ) + resp = server.create_token_response(request) + assert resp.status_code == 400 + data = json.loads(resp.content) + assert data["error"] == "invalid_request" + + +def test_unauthorized_client(factory, server, client): + client.grant_type = "invalid" + client.save() + request = factory.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + HTTP_AUTHORIZATION=create_basic_auth("client-id", "client-secret"), + ) + resp = server.create_token_response(request) + assert resp.status_code == 400 + data = json.loads(resp.content) + assert data["error"] == "unauthorized_client" + + +def test_authorize_token(factory, server): + request = factory.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + HTTP_AUTHORIZATION=create_basic_auth("client-id", "client-secret"), + ) + resp = server.create_token_response(request) + assert resp.status_code == 200 + data = json.loads(resp.content) + assert "access_token" in data diff --git a/tests/django/test_oauth2/test_refresh_token.py b/tests/django/test_oauth2/test_refresh_token.py index db8e48434..398ff9c49 100644 --- a/tests/django/test_oauth2/test_refresh_token.py +++ b/tests/django/test_oauth2/test_refresh_token.py @@ -1,182 +1,163 @@ import json -from authlib.oauth2.rfc6749.grants import ( - RefreshTokenGrant as _RefreshTokenGrant, -) -from .models import User, Client, OAuth2Token -from .oauth2_server import TestCase - - -class RefreshTokenGrant(_RefreshTokenGrant): - def authenticate_refresh_token(self, refresh_token): - try: - item = OAuth2Token.objects.get(refresh_token=refresh_token) - if item.is_refresh_token_active(): - return item - except OAuth2Token.DoesNotExist: - return None - - def authenticate_user(self, credential): - return credential.user - - def revoke_old_credential(self, credential): - credential.revoked = True - credential.save() - return credential - - -class RefreshTokenTest(TestCase): - def create_server(self): - server = super(RefreshTokenTest, self).create_server() - server.register_grant(RefreshTokenGrant) - return server - - def prepare_client(self, grant_type='refresh_token', scope=''): - user = User(username='foo') - user.save() - client = Client( - user_id=user.pk, - client_id='client', - client_secret='secret', - scope=scope, - grant_type=grant_type, - token_endpoint_auth_method='client_secret_basic', - default_redirect_uri='https://a.b', - ) - client.save() - - def prepare_token(self, scope='profile', user_id=1): - token = OAuth2Token( - user_id=user_id, - client_id='client', - token_type='bearer', - access_token='a1', - refresh_token='r1', - scope=scope, - expires_in=3600, - ) - token.save() - - def test_invalid_client(self): - server = self.create_server() - self.prepare_client() - request = self.factory.post( - '/oauth/token', - data={'grant_type': 'refresh_token', 'refresh_token': 'foo'}, - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 401) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_client') - - request = self.factory.post( - '/oauth/token', - data={'grant_type': 'refresh_token', 'refresh_token': 'foo'}, - HTTP_AUTHORIZATION=self.create_basic_auth('invalid', 'secret'), - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 401) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_client') - - def test_invalid_refresh_token(self): - self.prepare_client() - server = self.create_server() - auth_header = self.create_basic_auth('client', 'secret') - request = self.factory.post( - '/oauth/token', - data={'grant_type': 'refresh_token'}, - HTTP_AUTHORIZATION=auth_header - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 400) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_request') - self.assertIn('Missing', data['error_description']) - - request = self.factory.post( - '/oauth/token', - data={'grant_type': 'refresh_token', 'refresh_token': 'invalid'}, - HTTP_AUTHORIZATION=auth_header - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 400) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_grant') - - def test_invalid_scope(self): - server = self.create_server() - server.metadata = {'scopes_supported': ['profile']} - self.prepare_client() - self.prepare_token() - request = self.factory.post( - '/oauth/token', - data={ - 'grant_type': 'refresh_token', - 'refresh_token': 'r1', - 'scope': 'invalid', - }, - HTTP_AUTHORIZATION=self.create_basic_auth('client', 'secret'), - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 400) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_scope') - - def test_authorize_tno_scope(self): - server = self.create_server() - self.prepare_client() - self.prepare_token() - - request = self.factory.post( - '/oauth/token', - data={ - 'grant_type': 'refresh_token', - 'refresh_token': 'r1', - }, - HTTP_AUTHORIZATION=self.create_basic_auth('client', 'secret'), - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 200) - data = json.loads(resp.content) - self.assertIn('access_token', data) - - def test_authorize_token_scope(self): - server = self.create_server() - self.prepare_client() - self.prepare_token() - - request = self.factory.post( - '/oauth/token', - data={ - 'grant_type': 'refresh_token', - 'refresh_token': 'r1', - 'scope': 'profile', - }, - HTTP_AUTHORIZATION=self.create_basic_auth('client', 'secret'), - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 200) - data = json.loads(resp.content) - self.assertIn('access_token', data) - - def test_revoke_old_token(self): - server = self.create_server() - self.prepare_client() - self.prepare_token() - - request = self.factory.post( - '/oauth/token', - data={ - 'grant_type': 'refresh_token', - 'refresh_token': 'r1', - 'scope': 'profile', - }, - HTTP_AUTHORIZATION=self.create_basic_auth('client', 'secret'), - ) - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 200) - data = json.loads(resp.content) - self.assertIn('access_token', data) - - resp = server.create_token_response(request) - self.assertEqual(resp.status_code, 400) +import time + +import pytest + +from authlib.oauth2.rfc6749.grants import RefreshTokenGrant as _RefreshTokenGrant + +from .models import Client +from .models import OAuth2Token +from .oauth2_server import create_basic_auth + + +@pytest.fixture(autouse=True) +def server(server): + class RefreshTokenGrant(_RefreshTokenGrant): + def authenticate_refresh_token(self, refresh_token): + try: + item = OAuth2Token.objects.get(refresh_token=refresh_token) + if item.is_refresh_token_active(): + return item + except OAuth2Token.DoesNotExist: + return None + + def authenticate_user(self, credential): + return credential.user + + def revoke_old_credential(self, credential): + now = int(time.time()) + credential.access_token_revoked_at = now + credential.refresh_token_revoked_at = now + credential.save() + return credential + + server.register_grant(RefreshTokenGrant) + return server + + +@pytest.fixture(autouse=True) +def client(user): + client = Client( + user_id=user.pk, + client_id="client-id", + client_secret="client-secret", + scope="", + grant_type="refresh_token", + token_endpoint_auth_method="client_secret_basic", + default_redirect_uri="https://client.test", + ) + client.save() + yield client + client.delete() + + +def test_invalid_client(factory, server): + request = factory.post( + "/oauth/token", + data={"grant_type": "refresh_token", "refresh_token": "foo"}, + ) + resp = server.create_token_response(request) + assert resp.status_code == 401 + data = json.loads(resp.content) + assert data["error"] == "invalid_client" + + request = factory.post( + "/oauth/token", + data={"grant_type": "refresh_token", "refresh_token": "foo"}, + HTTP_AUTHORIZATION=create_basic_auth("invalid", "client-secret"), + ) + resp = server.create_token_response(request) + assert resp.status_code == 401 + data = json.loads(resp.content) + assert data["error"] == "invalid_client" + + +def test_invalid_refresh_token(factory, server): + auth_header = create_basic_auth("client-id", "client-secret") + request = factory.post( + "/oauth/token", + data={"grant_type": "refresh_token"}, + HTTP_AUTHORIZATION=auth_header, + ) + resp = server.create_token_response(request) + assert resp.status_code == 400 + data = json.loads(resp.content) + assert data["error"] == "invalid_request" + assert "Missing" in data["error_description"] + + request = factory.post( + "/oauth/token", + data={"grant_type": "refresh_token", "refresh_token": "invalid"}, + HTTP_AUTHORIZATION=auth_header, + ) + resp = server.create_token_response(request) + assert resp.status_code == 400 + data = json.loads(resp.content) + assert data["error"] == "invalid_grant" + + +def test_invalid_scope(factory, server, token): + server.scopes_supported = ["profile"] + request = factory.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": "r1", + "scope": "invalid", + }, + HTTP_AUTHORIZATION=create_basic_auth("client-id", "client-secret"), + ) + resp = server.create_token_response(request) + assert resp.status_code == 400 + data = json.loads(resp.content) + assert data["error"] == "invalid_scope" + + +def test_authorize_tno_scope(factory, server, token): + request = factory.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": "r1", + }, + HTTP_AUTHORIZATION=create_basic_auth("client-id", "client-secret"), + ) + resp = server.create_token_response(request) + assert resp.status_code == 200 + data = json.loads(resp.content) + assert "access_token" in data + + +def test_authorize_token_scope(factory, server, token): + request = factory.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": "r1", + "scope": "profile", + }, + HTTP_AUTHORIZATION=create_basic_auth("client-id", "client-secret"), + ) + resp = server.create_token_response(request) + assert resp.status_code == 200 + data = json.loads(resp.content) + assert "access_token" in data + + +def test_revoke_old_token(factory, server, token): + request = factory.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": "r1", + "scope": "profile", + }, + HTTP_AUTHORIZATION=create_basic_auth("client-id", "client-secret"), + ) + resp = server.create_token_response(request) + assert resp.status_code == 200 + data = json.loads(resp.content) + assert "access_token" in data + + resp = server.create_token_response(request) + assert resp.status_code == 400 diff --git a/tests/django/test_oauth2/test_resource_protector.py b/tests/django/test_oauth2/test_resource_protector.py index f8cabcf71..cde04bd4c 100644 --- a/tests/django/test_oauth2/test_resource_protector.py +++ b/tests/django/test_oauth2/test_resource_protector.py @@ -1,145 +1,141 @@ import json -from authlib.integrations.django_oauth2 import ResourceProtector, BearerTokenValidator + +import pytest from django.http import JsonResponse -from .models import User, Client, OAuth2Token -from .oauth2_server import TestCase +from authlib.integrations.django_oauth2 import BearerTokenValidator +from authlib.integrations.django_oauth2 import ResourceProtector + +from .models import Client +from .models import OAuth2Token require_oauth = ResourceProtector() require_oauth.register_token_validator(BearerTokenValidator(OAuth2Token)) -class ResourceProtectorTest(TestCase): - def prepare_data(self, expires_in=3600, scope='profile'): - user = User(username='foo') - user.save() - client = Client( - user_id=user.pk, - client_id='client', - client_secret='secret', - scope='profile', - ) - client.save() - - token = OAuth2Token( - user_id=user.pk, - client_id=client.client_id, - token_type='bearer', - access_token='a1', - scope=scope, - expires_in=expires_in, - ) - token.save() - - def test_invalid_token(self): - @require_oauth('profile') - def get_user_profile(request): - user = request.oauth_token.user - return JsonResponse(dict(sub=user.pk, username=user.username)) - - self.prepare_data() - - request = self.factory.get('/user') - resp = get_user_profile(request) - self.assertEqual(resp.status_code, 401) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'missing_authorization') - - request = self.factory.get('/user', HTTP_AUTHORIZATION='invalid token') - resp = get_user_profile(request) - self.assertEqual(resp.status_code, 401) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'unsupported_token_type') - - request = self.factory.get('/user', HTTP_AUTHORIZATION='bearer token') - resp = get_user_profile(request) - self.assertEqual(resp.status_code, 401) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_token') - - def test_expired_token(self): - self.prepare_data(0) - - @require_oauth('profile') - def get_user_profile(request): - user = request.oauth_token.user - return JsonResponse(dict(sub=user.pk, username=user.username)) - - request = self.factory.get('/user', HTTP_AUTHORIZATION='bearer a1') - resp = get_user_profile(request) - self.assertEqual(resp.status_code, 401) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_token') - - def test_insufficient_token(self): - self.prepare_data() - - @require_oauth('email') - def get_user_email(request): - user = request.oauth_token.user - return JsonResponse(dict(email=user.email)) - - request = self.factory.get('/user/email', HTTP_AUTHORIZATION='bearer a1') - resp = get_user_email(request) - self.assertEqual(resp.status_code, 403) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'insufficient_scope') - - def test_access_resource(self): - self.prepare_data() - - @require_oauth('profile', optional=True) - def get_user_profile(request): - if request.oauth_token: - user = request.oauth_token.user - return JsonResponse(dict(sub=user.pk, username=user.username)) - return JsonResponse(dict(sub=0, username='anonymous')) - - request = self.factory.get('/user') - resp = get_user_profile(request) - self.assertEqual(resp.status_code, 200) - data = json.loads(resp.content) - self.assertEqual(data['username'], 'anonymous') - - request = self.factory.get('/user', HTTP_AUTHORIZATION='bearer a1') - resp = get_user_profile(request) - self.assertEqual(resp.status_code, 200) - data = json.loads(resp.content) - self.assertEqual(data['username'], 'foo') - - def test_scope_operator(self): - self.prepare_data() - - @require_oauth('profile email', 'AND') - def operator_and(request): +@pytest.fixture(autouse=True) +def client(user): + client = Client( + user_id=user.pk, + client_id="client-id", + client_secret="client-secret", + scope="profile", + ) + client.save() + yield client + client.delete() + + +def test_invalid_token(factory): + @require_oauth("profile") + def get_user_profile(request): + user = request.oauth_token.user + return JsonResponse(dict(sub=user.pk, username=user.username)) + + request = factory.get("/user") + resp = get_user_profile(request) + assert resp.status_code == 401 + data = json.loads(resp.content) + assert data["error"] == "missing_authorization" + + request = factory.get("/user", HTTP_AUTHORIZATION="invalid token") + resp = get_user_profile(request) + assert resp.status_code == 401 + data = json.loads(resp.content) + assert data["error"] == "unsupported_token_type" + + request = factory.get("/user", HTTP_AUTHORIZATION="bearer token") + resp = get_user_profile(request) + assert resp.status_code == 401 + data = json.loads(resp.content) + assert data["error"] == "invalid_token" + + +def test_expired_token(factory, token): + token.expires_in = -10 + token.save() + + @require_oauth("profile") + def get_user_profile(request): + user = request.oauth_token.user + return JsonResponse(dict(sub=user.pk, username=user.username)) + + request = factory.get("/user", HTTP_AUTHORIZATION="bearer a1") + resp = get_user_profile(request) + assert resp.status_code == 401 + data = json.loads(resp.content) + assert data["error"] == "invalid_token" + + +def test_insufficient_token(factory, token): + @require_oauth("email") + def get_user_email(request): + user = request.oauth_token.user + return JsonResponse(dict(email=user.email)) + + request = factory.get("/user/email", HTTP_AUTHORIZATION="bearer a1") + resp = get_user_email(request) + assert resp.status_code == 403 + data = json.loads(resp.content) + assert data["error"] == "insufficient_scope" + + +def test_access_resource(factory, token): + @require_oauth("profile", optional=True) + def get_user_profile(request): + if request.oauth_token: user = request.oauth_token.user return JsonResponse(dict(sub=user.pk, username=user.username)) - - @require_oauth('profile email', 'OR') - def operator_or(request): - user = request.oauth_token.user - return JsonResponse(dict(sub=user.pk, username=user.username)) - - request = self.factory.get('/user', HTTP_AUTHORIZATION='bearer a1') - resp = operator_and(request) - self.assertEqual(resp.status_code, 403) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'insufficient_scope') - - resp = operator_or(request) - self.assertEqual(resp.status_code, 200) - data = json.loads(resp.content) - self.assertEqual(data['username'], 'foo') - - def scope_operator(token_scopes, resource_scopes): - return 'profile' in token_scopes and 'email' not in token_scopes - - @require_oauth(operator=scope_operator) - def operator_func(request): - user = request.oauth_token.user - return JsonResponse(dict(sub=user.pk, username=user.username)) - - resp = operator_func(request) - self.assertEqual(resp.status_code, 200) - data = json.loads(resp.content) - self.assertEqual(data['username'], 'foo') + return JsonResponse(dict(sub=0, username="anonymous")) + + request = factory.get("/user") + resp = get_user_profile(request) + assert resp.status_code == 200 + data = json.loads(resp.content) + assert data["username"] == "anonymous" + + request = factory.get("/user", HTTP_AUTHORIZATION="bearer a1") + resp = get_user_profile(request) + assert resp.status_code == 200 + data = json.loads(resp.content) + assert data["username"] == "foo" + + +def test_scope_operator(factory, token): + @require_oauth(["profile email"]) + def operator_and(request): + user = request.oauth_token.user + return JsonResponse(dict(sub=user.pk, username=user.username)) + + @require_oauth(["profile", "email"]) + def operator_or(request): + user = request.oauth_token.user + return JsonResponse(dict(sub=user.pk, username=user.username)) + + request = factory.get("/user", HTTP_AUTHORIZATION="bearer a1") + resp = operator_and(request) + assert resp.status_code == 403 + data = json.loads(resp.content) + assert data["error"] == "insufficient_scope" + + resp = operator_or(request) + assert resp.status_code == 200 + data = json.loads(resp.content) + assert data["username"] == "foo" + + +def test_decorator_without_parentheses(factory, token): + @require_oauth + def get_resource(request): + user = request.oauth_token.user + return JsonResponse(dict(sub=user.pk, username=user.username)) + + request = factory.get("/resource") + resp = get_resource(request) + assert resp.status_code == 401 + + request = factory.get("/resource", HTTP_AUTHORIZATION="bearer a1") + resp = get_resource(request) + assert resp.status_code == 200 + data = json.loads(resp.content) + assert data["username"] == "foo" diff --git a/tests/django/test_oauth2/test_revocation_endpoint.py b/tests/django/test_oauth2/test_revocation_endpoint.py index 2227f30e0..b1b320926 100644 --- a/tests/django/test_oauth2/test_revocation_endpoint.py +++ b/tests/django/test_oauth2/test_revocation_endpoint.py @@ -1,135 +1,123 @@ import json + +import pytest + from authlib.integrations.django_oauth2 import RevocationEndpoint -from .oauth2_server import TestCase -from .models import User, OAuth2Token, Client +from .models import Client +from .oauth2_server import create_basic_auth ENDPOINT_NAME = RevocationEndpoint.ENDPOINT_NAME -class RevocationEndpointTest(TestCase): - def create_server(self): - server = super(RevocationEndpointTest, self).create_server() - server.register_endpoint(RevocationEndpoint) - return server - - def prepare_client(self): - user = User(username='foo') - user.save() - client = Client( - user_id=user.pk, - client_id='client', - client_secret='secret', - token_endpoint_auth_method='client_secret_basic', - default_redirect_uri='https://a.b', - ) - client.save() - - def prepare_token(self, scope='profile', user_id=1): - token = OAuth2Token( - user_id=user_id, - client_id='client', - token_type='bearer', - access_token='a1', - refresh_token='r1', - scope=scope, - expires_in=3600, - ) - token.save() - - def test_invalid_client(self): - server = self.create_server() - request = self.factory.post('/oauth/revoke') - resp = server.create_endpoint_response(ENDPOINT_NAME, request) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_client') - - request = self.factory.post('/oauth/revoke', HTTP_AUTHORIZATION='invalid token') - resp = server.create_endpoint_response(ENDPOINT_NAME, request) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_client') - - request = self.factory.post( - '/oauth/revoke', - HTTP_AUTHORIZATION=self.create_basic_auth('invalid', 'secret'), - ) - resp = server.create_endpoint_response(ENDPOINT_NAME, request) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_client') - - request = self.factory.post( - '/oauth/revoke', - HTTP_AUTHORIZATION=self.create_basic_auth('client', 'invalid'), - ) - resp = server.create_endpoint_response(ENDPOINT_NAME, request) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_client') - - def test_invalid_token(self): - server = self.create_server() - self.prepare_client() - self.prepare_token() - auth_header = self.create_basic_auth('client', 'secret') - - request = self.factory.post('/oauth/revoke', HTTP_AUTHORIZATION=auth_header) - resp = server.create_endpoint_response(ENDPOINT_NAME, request) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'invalid_request') - - # case 1 - request = self.factory.post( - '/oauth/revoke', - data={'token': 'invalid-token'}, - HTTP_AUTHORIZATION=auth_header, - ) - resp = server.create_endpoint_response(ENDPOINT_NAME, request) - self.assertEqual(resp.status_code, 200) - - # case 2 - request = self.factory.post( - '/oauth/revoke', - data={ - 'token': 'a1', - 'token_type_hint': 'unsupported_token_type', - }, - HTTP_AUTHORIZATION=auth_header, - ) - resp = server.create_endpoint_response(ENDPOINT_NAME, request) - data = json.loads(resp.content) - self.assertEqual(data['error'], 'unsupported_token_type') - - # case 3 - request = self.factory.post( - '/oauth/revoke', - data={ - 'token': 'a1', - 'token_type_hint': 'refresh_token', - }, - HTTP_AUTHORIZATION=auth_header, - ) - resp = server.create_endpoint_response(ENDPOINT_NAME, request) - self.assertEqual(resp.status_code, 200) - - def test_revoke_token_with_hint(self): - self.prepare_client() - self.prepare_token() - self.revoke_token({'token': 'a1', 'token_type_hint': 'access_token'}) - self.revoke_token({'token': 'r1', 'token_type_hint': 'refresh_token'}) - - def test_revoke_token_without_hint(self): - self.prepare_client() - self.prepare_token() - self.revoke_token({'token': 'a1'}) - self.revoke_token({'token': 'r1'}) - - def revoke_token(self, data): - server = self.create_server() - auth_header = self.create_basic_auth('client', 'secret') - - request = self.factory.post( - '/oauth/revoke', - data=data, - HTTP_AUTHORIZATION=auth_header, - ) - resp = server.create_endpoint_response(ENDPOINT_NAME, request) - self.assertEqual(resp.status_code, 200) +@pytest.fixture(autouse=True) +def server(server): + server.register_endpoint(RevocationEndpoint) + return server + + +@pytest.fixture(autouse=True) +def client(user): + client = Client( + user_id=user.pk, + client_id="client-id", + client_secret="client-secret", + token_endpoint_auth_method="client_secret_basic", + default_redirect_uri="https://client.test", + ) + client.save() + yield client + client.delete() + + +def test_invalid_client(factory, server): + request = factory.post("/oauth/revoke") + resp = server.create_endpoint_response(ENDPOINT_NAME, request) + data = json.loads(resp.content) + assert data["error"] == "invalid_client" + + request = factory.post("/oauth/revoke", HTTP_AUTHORIZATION="invalid token") + resp = server.create_endpoint_response(ENDPOINT_NAME, request) + data = json.loads(resp.content) + assert data["error"] == "invalid_client" + + request = factory.post( + "/oauth/revoke", + HTTP_AUTHORIZATION=create_basic_auth("invalid", "client-secret"), + ) + resp = server.create_endpoint_response(ENDPOINT_NAME, request) + data = json.loads(resp.content) + assert data["error"] == "invalid_client" + + request = factory.post( + "/oauth/revoke", + HTTP_AUTHORIZATION=create_basic_auth("client-id", "invalid"), + ) + resp = server.create_endpoint_response(ENDPOINT_NAME, request) + data = json.loads(resp.content) + assert data["error"] == "invalid_client" + + +def test_invalid_token(factory, server, token): + auth_header = create_basic_auth("client-id", "client-secret") + + request = factory.post("/oauth/revoke", HTTP_AUTHORIZATION=auth_header) + resp = server.create_endpoint_response(ENDPOINT_NAME, request) + data = json.loads(resp.content) + assert data["error"] == "invalid_request" + + # case 1 + request = factory.post( + "/oauth/revoke", + data={"token": "invalid-token"}, + HTTP_AUTHORIZATION=auth_header, + ) + resp = server.create_endpoint_response(ENDPOINT_NAME, request) + assert resp.status_code == 200 + + # case 2 + request = factory.post( + "/oauth/revoke", + data={ + "token": "a1", + "token_type_hint": "unsupported_token_type", + }, + HTTP_AUTHORIZATION=auth_header, + ) + resp = server.create_endpoint_response(ENDPOINT_NAME, request) + data = json.loads(resp.content) + assert data["error"] == "unsupported_token_type" + + # case 3 + request = factory.post( + "/oauth/revoke", + data={ + "token": "a1", + "token_type_hint": "refresh_token", + }, + HTTP_AUTHORIZATION=auth_header, + ) + resp = server.create_endpoint_response(ENDPOINT_NAME, request) + assert resp.status_code == 200 + + +def test_revoke_token_with_hint(factory, server, token): + revoke_token(server, factory, {"token": "a1", "token_type_hint": "access_token"}) + revoke_token(server, factory, {"token": "r1", "token_type_hint": "refresh_token"}) + + +def test_revoke_token_without_hint(factory, server, token): + revoke_token(server, factory, {"token": "a1"}) + revoke_token(server, factory, {"token": "r1"}) + + +def revoke_token(server, factory, data): + auth_header = create_basic_auth("client-id", "client-secret") + + request = factory.post( + "/oauth/revoke", + data=data, + HTTP_AUTHORIZATION=auth_header, + ) + resp = server.create_endpoint_response(ENDPOINT_NAME, request) + assert resp.status_code == 200 diff --git a/tests/django/base.py b/tests/django_helper.py similarity index 77% rename from tests/django/base.py rename to tests/django_helper.py index a218cf503..637e003ff 100644 --- a/tests/django/base.py +++ b/tests/django_helper.py @@ -1,5 +1,5 @@ -from django.test import TestCase as _TestCase, RequestFactory from django.conf import settings +from django.test import RequestFactory from django.utils.module_loading import import_module @@ -15,8 +15,3 @@ def session(self): session.save() self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key return session - - -class TestCase(_TestCase): - def setUp(self): - self.factory = RequestClient() diff --git a/tests/django_settings.py b/tests/django_settings.py new file mode 100644 index 000000000..dba072069 --- /dev/null +++ b/tests/django_settings.py @@ -0,0 +1,39 @@ +SECRET_KEY = "django-secret" + +DATABASES = { + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": ":memory:", + } +} + +MIDDLEWARE = ["django.contrib.sessions.middleware.SessionMiddleware"] + +SESSION_ENGINE = "django.contrib.sessions.backends.cache" + +CACHES = { + "default": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", + "LOCATION": "unique-snowflake", + } +} + +INSTALLED_APPS = [ + "django.contrib.contenttypes", + "django.contrib.auth", + "tests.django.test_oauth1", + "tests.django.test_oauth2", +] + +AUTHLIB_OAUTH_CLIENTS = { + "dev_overwrite": { + "client_id": "dev-client-id", + "client_secret": "dev-client-secret", + "access_token_params": {"foo": "foo-1", "bar": "bar-2"}, + } +} + +USE_TZ = True + +# Default OAuth1 configuration for tests +AUTHLIB_OAUTH1_PROVIDER = {"signature_methods": ["PLAINTEXT", "HMAC-SHA1"]} diff --git a/tests/files/jwks_single_private.json b/tests/files/jwks_single_private.json new file mode 100644 index 000000000..8a0b33b77 --- /dev/null +++ b/tests/files/jwks_single_private.json @@ -0,0 +1,5 @@ +{ + "keys": [ + {"kty": "RSA", "n": "pF1JaMSN8TEsh4N4O_5SpEAVLivJyLH-Cgl3OQBPGgJkt8cg49oasl-5iJS-VdrILxWM9_JCJyURpUuslX4Eb4eUBtQ0x5BaPa8-S2NLdGTaL7nBOO8o8n0C5FEUU-qlEip79KE8aqOj-OC44VsIquSmOvWIQD26n3fCVlgwoRBD1gzzsDOeaSyzpKrZR851Kh6rEmF2qjJ8jt6EkxMsRNACmBomzgA4M1TTsisSUO87444pe35Z4_n5c735o2fZMrGgMwiJNh7rT8SYxtIkxngioiGnwkxGQxQ4NzPAHg-XSY0J04pNm7KqTkgtxyrqOANJLIjXlR-U9SQ90NjHVQ", "e": "AQAB", "d": "G4E84ppZwm3fLMI0YZ26iJ_sq3BKcRpQD6_r0o8ZrZmO7y4Uc-ywoP7h1lhFzaox66cokuloZpKOdGHIfK-84EkI3WeveWHPqBjmTMlN_ClQVcI48mUbLhD7Zeenhi9y9ipD2fkNWi8OJny8k4GfXrGqm50w8schrsPksnxJjvocGMT6KZNfDURKF2HlM5X1uY8VCofokXOjBEeHIfYM8e7IcmPpyXwXKonDmVVbMbefo-u-TttgeyOYaO6s3flSy6Y0CnpWi43JQ_VEARxQl6Brj1oizr8UnQQ0nNCOWwDNVtOV4eSl7PZoiiT7CxYkYnhJXECMAM5YBpm4Qk9zdQ", "p": "1g4ZGrXOuo75p9_MRIepXGpBWxip4V7B9XmO9WzPCv8nMorJntWBmsYV1I01aITxadHatO4Gl2xLniNkDyrEQzJ7w38RQgsVK-CqbnC0K9N77QPbHeC1YQd9RCNyUohOimKvb7jyv798FBU1GO5QI2eNgfnnfteSVXhD2iOoTOs", "q": "xJJ-8toxJdnLa0uUsAbql6zeNXGbUBMzu3FomKlyuWuq841jS2kIalaO_TRj5hbnE45jmCjeLgTVO6Ach3Wfk4zrqajqfFJ0zUg_Wexp49lC3RWiV4icBb85Q6bzeJD9Dn9vhjpfWVkczf_NeA1fGH_pcgfkT6Dm706GFFttLL8", "dp": "Zfx3l5NR-O8QIhzuHSSp279Afl_E6P0V2phdNa_vAaVKDrmzkHrXcl-4nPnenXrh7vIuiw_xkgnmCWWBUfylYALYlu-e0GGpZ6t2aIJIRa1QmT_CEX0zzhQcae-dk5cgHK0iO0_aUOOyAXuNPeClzAiVknz4ACZDsXdIlNFyaZs", "dq": "Z9DG4xOBKXBhEoWUPXMpqnlN0gPx9tRtWe2HRDkZsfu_CWn-qvEJ1L9qPSfSKs6ls5pb1xyeWseKpjblWlUwtgiS3cOsM4SI03H4o1FMi11PBtxKJNitLgvT_nrJ0z8fpux-xfFGMjXyFImoxmKpepLzg5nPZo6f6HscLNwsSJk", "qi": "Sk20wFvilpRKHq79xxFWiDUPHi0x0pp82dYIEntGQkKUWkbSlhgf3MAi5NEQTDmXdnB-rVeWIvEi-BXfdnNgdn8eC4zSdtF4sIAhYr5VWZo0WVWDhT7u2ccvZBFymiz8lo3gN57wGUCi9pbZqzV1-ZppX6YTNDdDCE0q-KO3Cec"} + ] +} diff --git a/tests/files/jwks_single_public.json b/tests/files/jwks_single_public.json new file mode 100644 index 000000000..c47e1dd8f --- /dev/null +++ b/tests/files/jwks_single_public.json @@ -0,0 +1,5 @@ +{ + "keys": [ + {"kty": "RSA", "kid": "abc", "n": "pF1JaMSN8TEsh4N4O_5SpEAVLivJyLH-Cgl3OQBPGgJkt8cg49oasl-5iJS-VdrILxWM9_JCJyURpUuslX4Eb4eUBtQ0x5BaPa8-S2NLdGTaL7nBOO8o8n0C5FEUU-qlEip79KE8aqOj-OC44VsIquSmOvWIQD26n3fCVlgwoRBD1gzzsDOeaSyzpKrZR851Kh6rEmF2qjJ8jt6EkxMsRNACmBomzgA4M1TTsisSUO87444pe35Z4_n5c735o2fZMrGgMwiJNh7rT8SYxtIkxngioiGnwkxGQxQ4NzPAHg-XSY0J04pNm7KqTkgtxyrqOANJLIjXlR-U9SQ90NjHVQ", "e": "AQAB"} + ] +} diff --git a/tests/files/secp256k1-private.pem b/tests/files/secp256k1-private.pem new file mode 100644 index 000000000..9e1d30ae6 --- /dev/null +++ b/tests/files/secp256k1-private.pem @@ -0,0 +1,5 @@ +-----BEGIN PRIVATE KEY----- +MIGEAgEAMBAGByqGSM49AgEGBSuBBAAKBG0wawIBAQQgTHXBopHraQcg1U8bPK63 +eO5tNMt5ZcHo/1RsJkSnLAahRANCAAROhceIcao7c/9Ei6PgBLr3+UgDbkxSCJ0d +KDtXgKipXfrI1mVHys/FJ0TzvNPCEZNpPPeWYd/sr5V6ADhdQsHe +-----END PRIVATE KEY----- diff --git a/tests/files/secp256k1-pub.pem b/tests/files/secp256k1-pub.pem new file mode 100644 index 000000000..46faabccd --- /dev/null +++ b/tests/files/secp256k1-pub.pem @@ -0,0 +1,4 @@ +-----BEGIN PUBLIC KEY----- +MFYwEAYHKoZIzj0CAQYFK4EEAAoDQgAEToXHiHGqO3P/RIuj4AS69/lIA25MUgid +HSg7V4CoqV36yNZlR8rPxSdE87zTwhGTaTz3lmHf7K+VegA4XULB3g== +-----END PUBLIC KEY----- diff --git a/tests/files/ec_private.json b/tests/files/secp521r1-private.json similarity index 100% rename from tests/files/ec_private.json rename to tests/files/secp521r1-private.json diff --git a/tests/files/ec_public.json b/tests/files/secp521r1-public.json similarity index 100% rename from tests/files/ec_public.json rename to tests/files/secp521r1-public.json diff --git a/tests/flask/cache.py b/tests/flask/cache.py index b3c77592a..282e5bc77 100644 --- a/tests/flask/cache.py +++ b/tests/flask/cache.py @@ -1,11 +1,12 @@ import time + try: import cPickle as pickle except ImportError: import pickle -class SimpleCache(object): +class SimpleCache: """A SimpleCache for testing. Copied from Werkzeug.""" def __init__(self, threshold=500, default_timeout=300): @@ -42,9 +43,7 @@ def get(self, key): def set(self, key, value, timeout=None): expires = self._normalize_timeout(timeout) self._prune() - self._cache[key] = ( - expires, pickle.dumps(value, pickle.HIGHEST_PROTOCOL) - ) + self._cache[key] = (expires, pickle.dumps(value, pickle.HIGHEST_PROTOCOL)) return True def delete(self, key): diff --git a/tests/flask/test_client/test_oauth_client.py b/tests/flask/test_client/test_oauth_client.py deleted file mode 100644 index dedf7b9c5..000000000 --- a/tests/flask/test_client/test_oauth_client.py +++ /dev/null @@ -1,411 +0,0 @@ -import mock -from unittest import TestCase -from flask import Flask, session -from authlib.integrations.flask_client import OAuth, OAuthError -from authlib.integrations.flask_client import FlaskRemoteApp -from tests.flask.cache import SimpleCache -from tests.client_base import ( - mock_send_value, - get_bearer_token -) - - -class FlaskOAuthTest(TestCase): - def test_register_remote_app(self): - app = Flask(__name__) - oauth = OAuth(app) - self.assertRaises(AttributeError, lambda: oauth.dev) - - oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - ) - self.assertEqual(oauth.dev.name, 'dev') - self.assertEqual(oauth.dev.client_id, 'dev') - - def test_register_conf_from_app(self): - app = Flask(__name__) - app.config.update({ - 'DEV_CLIENT_ID': 'dev', - 'DEV_CLIENT_SECRET': 'dev', - }) - oauth = OAuth(app) - oauth.register('dev') - self.assertEqual(oauth.dev.client_id, 'dev') - - def test_register_with_overwrite(self): - app = Flask(__name__) - app.config.update({ - 'DEV_CLIENT_ID': 'dev-1', - 'DEV_CLIENT_SECRET': 'dev', - 'DEV_ACCESS_TOKEN_PARAMS': {'foo': 'foo-1'} - }) - oauth = OAuth(app) - oauth.register( - 'dev', overwrite=True, - client_id='dev', - access_token_params={'foo': 'foo'} - ) - self.assertEqual(oauth.dev.client_id, 'dev-1') - self.assertEqual(oauth.dev.client_secret, 'dev') - self.assertEqual(oauth.dev.access_token_params['foo'], 'foo-1') - - def test_init_app_later(self): - app = Flask(__name__) - app.config.update({ - 'DEV_CLIENT_ID': 'dev', - 'DEV_CLIENT_SECRET': 'dev', - }) - oauth = OAuth() - remote = oauth.register('dev') - self.assertRaises(RuntimeError, lambda: oauth.dev.client_id) - oauth.init_app(app) - self.assertEqual(oauth.dev.client_id, 'dev') - self.assertEqual(remote.client_id, 'dev') - - self.assertIsNone(oauth.cache) - self.assertIsNone(oauth.fetch_token) - self.assertIsNone(oauth.update_token) - - def test_init_app_params(self): - app = Flask(__name__) - oauth = OAuth() - oauth.init_app(app, SimpleCache()) - self.assertIsNotNone(oauth.cache) - self.assertIsNone(oauth.update_token) - - oauth.init_app(app, update_token=lambda o: o) - self.assertIsNotNone(oauth.update_token) - - def test_create_client(self): - app = Flask(__name__) - oauth = OAuth(app) - self.assertIsNone(oauth.create_client('dev')) - oauth.register('dev', client_id='dev') - self.assertIsNotNone(oauth.create_client('dev')) - - def test_register_oauth1_remote_app(self): - app = Flask(__name__) - oauth = OAuth(app) - oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - request_token_url='https://i.b/reqeust-token', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize' - ) - self.assertEqual(oauth.dev.name, 'dev') - self.assertEqual(oauth.dev.client_id, 'dev') - - def test_oauth1_authorize(self): - app = Flask(__name__) - app.secret_key = '!' - oauth = OAuth(app, cache=SimpleCache()) - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - request_token_url='https://i.b/reqeust-token', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize' - ) - - with app.test_request_context(): - with mock.patch('requests.sessions.Session.send') as send: - send.return_value = mock_send_value('oauth_token=foo&oauth_verifier=baz') - resp = client.authorize_redirect('https://b.com/bar') - self.assertEqual(resp.status_code, 302) - url = resp.headers.get('Location') - self.assertIn('oauth_token=foo', url) - self.assertIsNotNone(session.get('_dev_authlib_req_token_')) - - with mock.patch('requests.sessions.Session.send') as send: - send.return_value = mock_send_value('oauth_token=a&oauth_token_secret=b') - token = client.authorize_access_token() - self.assertEqual(token['oauth_token'], 'a') - - def test_register_oauth2_remote_app(self): - app = Flask(__name__) - oauth = OAuth(app) - oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - refresh_token_url='https://i.b/token', - authorize_url='https://i.b/authorize', - update_token=lambda name: 'hi' - ) - self.assertEqual(oauth.dev.name, 'dev') - session = oauth.dev._get_oauth_client() - self.assertIsNotNone(session.update_token) - - def test_oauth2_authorize(self): - app = Flask(__name__) - app.secret_key = '!' - oauth = OAuth(app) - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize' - ) - - with app.test_request_context(): - resp = client.authorize_redirect('https://b.com/bar') - self.assertEqual(resp.status_code, 302) - url = resp.headers.get('Location') - self.assertIn('state=', url) - state = session['_dev_authlib_state_'] - self.assertIsNotNone(state) - - with app.test_request_context(path='/?code=a&state={}'.format(state)): - # session is cleared in tests - session['_dev_authlib_state_'] = state - - with mock.patch('requests.sessions.Session.send') as send: - send.return_value = mock_send_value(get_bearer_token()) - token = client.authorize_access_token() - self.assertEqual(token['access_token'], 'a') - - with app.test_request_context(): - self.assertEqual(client.token, None) - - def test_oauth2_authorize_via_custom_client(self): - class CustomRemoteApp(FlaskRemoteApp): - OAUTH_APP_CONFIG = {'authorize_url': 'https://i.b/custom'} - - app = Flask(__name__) - app.secret_key = '!' - oauth = OAuth(app) - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - client_cls=CustomRemoteApp, - ) - with app.test_request_context(): - resp = client.authorize_redirect('https://b.com/bar') - self.assertEqual(resp.status_code, 302) - url = resp.headers.get('Location') - self.assertTrue(url.startswith('https://i.b/custom?')) - - def test_oauth2_authorize_with_metadata(self): - app = Flask(__name__) - app.secret_key = '!' - oauth = OAuth(app) - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - ) - self.assertRaises(RuntimeError, client.create_authorization_url) - - client = oauth.register( - 'dev2', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - server_metadata_url='https://i.b/.well-known/openid-configuration' - ) - with mock.patch('requests.sessions.Session.send') as send: - send.return_value = mock_send_value({ - 'authorization_endpoint': 'https://i.b/authorize' - }) - - with app.test_request_context(): - resp = client.authorize_redirect('https://b.com/bar') - self.assertEqual(resp.status_code, 302) - - def test_oauth2_authorize_code_challenge(self): - app = Flask(__name__) - app.secret_key = '!' - oauth = OAuth(app) - client = oauth.register( - 'dev', - client_id='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize', - client_kwargs={'code_challenge_method': 'S256'}, - ) - - with app.test_request_context(): - resp = client.authorize_redirect('https://b.com/bar') - self.assertEqual(resp.status_code, 302) - url = resp.headers.get('Location') - self.assertIn('code_challenge=', url) - self.assertIn('code_challenge_method=S256', url) - state = session['_dev_authlib_state_'] - self.assertIsNotNone(state) - verifier = session['_dev_authlib_code_verifier_'] - self.assertIsNotNone(verifier) - - def fake_send(sess, req, **kwargs): - self.assertIn('code_verifier={}'.format(verifier), req.body) - return mock_send_value(get_bearer_token()) - - path = '/?code=a&state={}'.format(state) - with app.test_request_context(path=path): - # session is cleared in tests - session['_dev_authlib_state_'] = state - session['_dev_authlib_code_verifier_'] = verifier - - with mock.patch('requests.sessions.Session.send', fake_send): - token = client.authorize_access_token() - self.assertEqual(token['access_token'], 'a') - - def test_openid_authorize(self): - app = Flask(__name__) - app.secret_key = '!' - oauth = OAuth(app) - client = oauth.register( - 'dev', - client_id='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize', - client_kwargs={'scope': 'openid profile'}, - ) - - with app.test_request_context(): - resp = client.authorize_redirect('https://b.com/bar') - self.assertEqual(resp.status_code, 302) - nonce = session['_dev_authlib_nonce_'] - self.assertIsNotNone(nonce) - url = resp.headers.get('Location') - self.assertIn('nonce={}'.format(nonce), url) - - def test_oauth2_access_token_with_post(self): - app = Flask(__name__) - app.secret_key = '!' - oauth = OAuth(app) - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize' - ) - payload = {'code': 'a', 'state': 'b'} - with app.test_request_context(data=payload, method='POST'): - session['_dev_authlib_state_'] = 'b' - with mock.patch('requests.sessions.Session.send') as send: - send.return_value = mock_send_value(get_bearer_token()) - token = client.authorize_access_token() - self.assertEqual(token['access_token'], 'a') - - def test_access_token_with_fetch_token(self): - app = Flask(__name__) - app.secret_key = '!' - oauth = OAuth() - - token = get_bearer_token() - oauth.init_app(app, fetch_token=lambda name: token) - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize' - ) - - def fake_send(sess, req, **kwargs): - auth = req.headers['Authorization'] - self.assertEqual(auth, 'Bearer {}'.format(token['access_token'])) - resp = mock.MagicMock() - resp.text = 'hi' - resp.status_code = 200 - return resp - - with app.test_request_context(): - with mock.patch('requests.sessions.Session.send', fake_send): - resp = client.get('/api/user') - self.assertEqual(resp.text, 'hi') - - # trigger ctx.authlib_client_oauth_token - resp = client.get('/api/user') - self.assertEqual(resp.text, 'hi') - - def test_request_with_refresh_token(self): - app = Flask(__name__) - app.secret_key = '!' - oauth = OAuth() - - expired_token = { - 'token_type': 'Bearer', - 'access_token': 'expired-a', - 'refresh_token': 'expired-b', - 'expires_in': '3600', - 'expires_at': 1566465749, - } - oauth.init_app(app, fetch_token=lambda name: expired_token) - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - refresh_token_url='https://i.b/token', - authorize_url='https://i.b/authorize' - ) - - def fake_send(sess, req, **kwargs): - if req.url == 'https://i.b/token': - auth = req.headers['Authorization'] - self.assertIn('Basic', auth) - resp = mock.MagicMock() - resp.json = get_bearer_token - resp.status_code = 200 - return resp - - resp = mock.MagicMock() - resp.text = 'hi' - resp.status_code = 200 - return resp - - with app.test_request_context(): - with mock.patch('requests.sessions.Session.send', fake_send): - resp = client.get('/api/user', token=expired_token) - self.assertEqual(resp.text, 'hi') - - def test_request_without_token(self): - app = Flask(__name__) - app.secret_key = '!' - oauth = OAuth(app) - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize' - ) - - def fake_send(sess, req, **kwargs): - auth = req.headers.get('Authorization') - self.assertIsNone(auth) - resp = mock.MagicMock() - resp.text = 'hi' - resp.status_code = 200 - return resp - - with app.test_request_context(): - with mock.patch('requests.sessions.Session.send', fake_send): - resp = client.get('/api/user', withhold_token=True) - self.assertEqual(resp.text, 'hi') - self.assertRaises(OAuthError, client.get, 'https://i.b/api/user') diff --git a/tests/flask/test_client/test_user_mixin.py b/tests/flask/test_client/test_user_mixin.py deleted file mode 100644 index 7b6d25f27..000000000 --- a/tests/flask/test_client/test_user_mixin.py +++ /dev/null @@ -1,175 +0,0 @@ -import mock -from unittest import TestCase -from flask import Flask, session -from authlib.jose import jwk -from authlib.jose.errors import InvalidClaimError -from authlib.integrations.flask_client import OAuth -from authlib.oidc.core.grants.util import generate_id_token -from tests.util import read_file_path -from tests.client_base import ( - get_bearer_token, -) - - -class FlaskUserMixinTest(TestCase): - def run_fetch_userinfo(self, payload, compliance_fix=None): - app = Flask(__name__) - app.secret_key = '!' - oauth = OAuth(app) - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - fetch_token=get_bearer_token, - userinfo_endpoint='https://i.b/userinfo', - userinfo_compliance_fix=compliance_fix, - ) - - def fake_send(sess, req, **kwargs): - resp = mock.MagicMock() - resp.json = lambda: payload - resp.status_code = 200 - return resp - - with app.test_request_context(): - with mock.patch('requests.sessions.Session.send', fake_send): - user = client.userinfo() - self.assertEqual(user.sub, '123') - - def test_fetch_userinfo(self): - self.run_fetch_userinfo({'sub': '123'}) - - def test_userinfo_compliance_fix(self): - def _fix(remote, data): - return {'sub': data['id']} - - self.run_fetch_userinfo({'id': '123'}, _fix) - - def test_parse_id_token(self): - key = jwk.dumps('secret', 'oct', kid='f') - token = get_bearer_token() - id_token = generate_id_token( - token, {'sub': '123'}, key, - alg='HS256', iss='https://i.b', - aud='dev', exp=3600, nonce='n', - ) - - app = Flask(__name__) - app.secret_key = '!' - oauth = OAuth(app) - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - fetch_token=get_bearer_token, - jwks={'keys': [key]}, - issuer='https://i.b', - id_token_signing_alg_values_supported=['HS256', 'RS256'], - ) - with app.test_request_context(): - session['_dev_authlib_nonce_'] = 'n' - self.assertIsNone(client.parse_id_token(token)) - - token['id_token'] = id_token - user = client.parse_id_token(token) - self.assertEqual(user.sub, '123') - - claims_options = {'iss': {'value': 'https://i.b'}} - user = client.parse_id_token(token, claims_options=claims_options) - self.assertEqual(user.sub, '123') - - claims_options = {'iss': {'value': 'https://i.c'}} - self.assertRaises( - InvalidClaimError, - client.parse_id_token, token, claims_options - ) - - def test_parse_id_token_nonce_supported(self): - key = jwk.dumps('secret', 'oct', kid='f') - token = get_bearer_token() - id_token = generate_id_token( - token, {'sub': '123', 'nonce_supported': False}, key, - alg='HS256', iss='https://i.b', - aud='dev', exp=3600, - ) - - app = Flask(__name__) - app.secret_key = '!' - oauth = OAuth(app) - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - fetch_token=get_bearer_token, - jwks={'keys': [key]}, - issuer='https://i.b', - id_token_signing_alg_values_supported=['HS256', 'RS256'], - ) - with app.test_request_context(): - session['_dev_authlib_nonce_'] = 'n' - token['id_token'] = id_token - user = client.parse_id_token(token) - self.assertEqual(user.sub, '123') - - def test_runtime_error_fetch_jwks_uri(self): - key = jwk.dumps('secret', 'oct', kid='f') - token = get_bearer_token() - id_token = generate_id_token( - token, {'sub': '123'}, key, - alg='HS256', iss='https://i.b', - aud='dev', exp=3600, nonce='n', - ) - - app = Flask(__name__) - app.secret_key = '!' - oauth = OAuth(app) - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - fetch_token=get_bearer_token, - jwks={'keys': [jwk.dumps('secret', 'oct', kid='b')]}, - issuer='https://i.b', - id_token_signing_alg_values_supported=['HS256'], - ) - with app.test_request_context(): - session['_dev_authlib_nonce_'] = 'n' - token['id_token'] = id_token - self.assertRaises(RuntimeError, client.parse_id_token, token) - - def test_force_fetch_jwks_uri(self): - secret_keys = read_file_path('jwks_private.json') - token = get_bearer_token() - id_token = generate_id_token( - token, {'sub': '123'}, secret_keys, - alg='RS256', iss='https://i.b', - aud='dev', exp=3600, nonce='n', - ) - - app = Flask(__name__) - app.secret_key = '!' - oauth = OAuth(app) - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - fetch_token=get_bearer_token, - jwks={'keys': [jwk.dumps('secret', 'oct', kid='f')]}, - jwks_uri='https://i.b/jwks', - issuer='https://i.b', - ) - - def fake_send(sess, req, **kwargs): - resp = mock.MagicMock() - resp.json = lambda: read_file_path('jwks_public.json') - resp.status_code = 200 - return resp - - with app.test_request_context(): - session['_dev_authlib_nonce_'] = 'n' - self.assertIsNone(client.parse_id_token(token)) - - with mock.patch('requests.sessions.Session.send', fake_send): - token['id_token'] = id_token - user = client.parse_id_token(token) - self.assertEqual(user.sub, '123') diff --git a/tests/flask/test_oauth1/conftest.py b/tests/flask/test_oauth1/conftest.py new file mode 100644 index 000000000..d72a53a33 --- /dev/null +++ b/tests/flask/test_oauth1/conftest.py @@ -0,0 +1,48 @@ +import os + +import pytest +from flask import Flask + + +@pytest.fixture(autouse=True) +def env(): + os.environ["AUTHLIB_INSECURE_TRANSPORT"] = "true" + yield + del os.environ["AUTHLIB_INSECURE_TRANSPORT"] + + +@pytest.fixture +def app(): + app = Flask(__name__) + app.debug = True + app.testing = True + app.secret_key = "testing" + app.config.update( + { + "OAUTH1_SUPPORTED_SIGNATURE_METHODS": [ + "PLAINTEXT", + "HMAC-SHA1", + "RSA-SHA1", + ], + "SQLALCHEMY_TRACK_MODIFICATIONS": False, + "SQLALCHEMY_DATABASE_URI": "sqlite://", + } + ) + + with app.app_context(): + yield app + + +@pytest.fixture +def test_client(app, db): + return app.test_client() + + +@pytest.fixture +def db(app): + from .oauth1_server import db + + db.init_app(app) + db.create_all() + yield db + db.drop_all() diff --git a/tests/flask/test_oauth1/oauth1_server.py b/tests/flask/test_oauth1/oauth1_server.py index 535d47cee..c70c4d89b 100644 --- a/tests/flask/test_oauth1/oauth1_server.py +++ b/tests/flask/test_oauth1/oauth1_server.py @@ -1,31 +1,23 @@ -import os -import unittest -from flask import Flask, request, jsonify +from flask import jsonify +from flask import request from flask_sqlalchemy import SQLAlchemy + +from authlib.common.urls import url_encode +from authlib.integrations.flask_oauth1 import AuthorizationServer +from authlib.integrations.flask_oauth1 import ResourceProtector from authlib.integrations.flask_oauth1 import ( - AuthorizationServer, ResourceProtector, current_credential -) -from authlib.integrations.sqla_oauth1 import ( - OAuth1ClientMixin, - OAuth1TokenCredentialMixin, - OAuth1TemporaryCredentialMixin, - OAuth1TimestampNonceMixin, - create_query_client_func, - create_query_token_func, - register_authorization_hooks, - create_exists_nonce_func as create_db_exists_nonce_func, -) -from authlib.integrations.flask_oauth1 import ( - register_temporary_credential_hooks, - register_nonce_hooks, create_exists_nonce_func as create_cache_exists_nonce_func, ) +from authlib.integrations.flask_oauth1 import current_credential +from authlib.integrations.flask_oauth1 import register_nonce_hooks +from authlib.integrations.flask_oauth1 import register_temporary_credential_hooks +from authlib.oauth1 import ClientMixin +from authlib.oauth1 import TemporaryCredentialMixin +from authlib.oauth1 import TokenCredentialMixin from authlib.oauth1.errors import OAuth1Error -from authlib.common.urls import url_encode from tests.util import read_file_path -from ..cache import SimpleCache -os.environ['AUTHLIB_INSECURE_TRANSPORT'] = 'true' +from ..cache import SimpleCache db = SQLAlchemy() @@ -38,39 +30,150 @@ def get_user_id(self): return self.id -class Client(db.Model, OAuth1ClientMixin): +class Client(ClientMixin, db.Model): id = db.Column(db.Integer, primary_key=True) - user_id = db.Column( - db.Integer, db.ForeignKey('user.id', ondelete='CASCADE') - ) - user = db.relationship('User') + client_id = db.Column(db.String(48), index=True) + client_secret = db.Column(db.String(120), nullable=False) + default_redirect_uri = db.Column(db.Text, nullable=False, default="") + user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE")) + user = db.relationship("User") + + def get_default_redirect_uri(self): + return self.default_redirect_uri + + def get_client_secret(self): + return self.client_secret def get_rsa_public_key(self): - return read_file_path('rsa_public.pem') + return read_file_path("rsa_public.pem") -class TokenCredential(db.Model, OAuth1TokenCredentialMixin): +class TokenCredential(TokenCredentialMixin, db.Model): id = db.Column(db.Integer, primary_key=True) - user_id = db.Column( - db.Integer, db.ForeignKey('user.id', ondelete='CASCADE') - ) - user = db.relationship('User') + user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE")) + user = db.relationship("User") + client_id = db.Column(db.String(48), index=True) + oauth_token = db.Column(db.String(84), unique=True, index=True) + oauth_token_secret = db.Column(db.String(84)) + + def get_oauth_token(self): + return self.oauth_token + def get_oauth_token_secret(self): + return self.oauth_token_secret -class TemporaryCredential(db.Model, OAuth1TemporaryCredentialMixin): + +class TemporaryCredential(TemporaryCredentialMixin, db.Model): id = db.Column(db.Integer, primary_key=True) - user_id = db.Column( - db.Integer, db.ForeignKey('user.id', ondelete='CASCADE') - ) - user = db.relationship('User') + user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE")) + user = db.relationship("User") + client_id = db.Column(db.String(48), index=True) + oauth_token = db.Column(db.String(84), unique=True, index=True) + oauth_token_secret = db.Column(db.String(84)) + oauth_verifier = db.Column(db.String(84)) + oauth_callback = db.Column(db.Text, default="") + + def get_user_id(self): + return self.user_id + + def get_client_id(self): + return self.client_id + + def get_redirect_uri(self): + return self.oauth_callback + def check_verifier(self, verifier): + return self.oauth_verifier == verifier -class TimestampNonce(db.Model, OAuth1TimestampNonceMixin): + def get_oauth_token(self): + return self.oauth_token + + def get_oauth_token_secret(self): + return self.oauth_token_secret + + +class TimestampNonce(db.Model): + __table_args__ = ( + db.UniqueConstraint( + "client_id", "timestamp", "nonce", "oauth_token", name="unique_nonce" + ), + ) id = db.Column(db.Integer, primary_key=True) + client_id = db.Column(db.String(48), nullable=False) + timestamp = db.Column(db.Integer, nullable=False) + nonce = db.Column(db.String(48), nullable=False) + oauth_token = db.Column(db.String(84)) + + +def exists_nonce(nonce, timestamp, client_id, oauth_token): + q = TimestampNonce.query.filter_by( + nonce=nonce, + timestamp=timestamp, + client_id=client_id, + ) + if oauth_token: + q = q.filter_by(oauth_token=oauth_token) + rv = q.first() + if rv: + return True + + item = TimestampNonce( + nonce=nonce, + timestamp=timestamp, + client_id=client_id, + oauth_token=oauth_token, + ) + db.session.add(item) + db.session.commit() + return False + + +def create_temporary_credential(token, client_id, redirect_uri): + item = TemporaryCredential( + client_id=client_id, + oauth_token=token["oauth_token"], + oauth_token_secret=token["oauth_token_secret"], + oauth_callback=redirect_uri, + ) + db.session.add(item) + db.session.commit() + return item + + +def get_temporary_credential(oauth_token): + return TemporaryCredential.query.filter_by(oauth_token=oauth_token).first() + + +def delete_temporary_credential(oauth_token): + q = TemporaryCredential.query.filter_by(oauth_token=oauth_token) + q.delete(synchronize_session=False) + db.session.commit() + + +def create_authorization_verifier(credential, grant_user, verifier): + credential.user_id = grant_user.id # assuming your end user model has `.id` + credential.oauth_verifier = verifier + db.session.add(credential) + db.session.commit() + return credential + + +def create_token_credential(token, temporary_credential): + credential = TokenCredential( + oauth_token=token["oauth_token"], + oauth_token_secret=token["oauth_token_secret"], + client_id=temporary_credential.get_client_id(), + ) + credential.user_id = temporary_credential.get_user_id() + db.session.add(credential) + db.session.commit() + return credential def create_authorization_server(app, use_cache=False, lazy=False): - query_client = create_query_client_func(db.session, Client) + def query_client(client_id): + return Client.query.filter_by(client_id=client_id).first() + if lazy: server = AuthorizationServer() server.init_app(app, query_client) @@ -80,30 +183,32 @@ def create_authorization_server(app, use_cache=False, lazy=False): cache = SimpleCache() register_nonce_hooks(server, cache) register_temporary_credential_hooks(server, cache) - register_authorization_hooks(server, db.session, TokenCredential) + server.register_hook("create_token_credential", create_token_credential) else: - register_authorization_hooks( - server, db.session, - token_credential_model=TokenCredential, - temporary_credential_model=TemporaryCredential, - timestamp_nonce_model=TimestampNonce, + server.register_hook("exists_nonce", exists_nonce) + server.register_hook("create_temporary_credential", create_temporary_credential) + server.register_hook("get_temporary_credential", get_temporary_credential) + server.register_hook("delete_temporary_credential", delete_temporary_credential) + server.register_hook( + "create_authorization_verifier", create_authorization_verifier ) + server.register_hook("create_token_credential", create_token_credential) - @app.route('/oauth/initiate', methods=['GET', 'POST']) + @app.route("/oauth/initiate", methods=["GET", "POST"]) def initiate(): return server.create_temporary_credentials_response() - @app.route('/oauth/authorize', methods=['GET', 'POST']) + @app.route("/oauth/authorize", methods=["GET", "POST"]) def authorize(): - if request.method == 'GET': + if request.method == "GET": try: server.check_authorization_request() - return 'ok' + return "ok" except OAuth1Error: - return 'error' - user_id = request.form.get('user_id') + return "error" + user_id = request.form.get("user_id") if user_id: - grant_user = User.query.get(int(user_id)) + grant_user = db.session.get(User, int(user_id)) else: grant_user = None try: @@ -111,7 +216,7 @@ def authorize(): except OAuth1Error as error: return url_encode(error.get_body()) - @app.route('/oauth/token', methods=['POST']) + @app.route("/oauth/token", methods=["POST"]) def issue_token(): return server.create_token_response() @@ -123,51 +228,51 @@ def create_resource_server(app, use_cache=False, lazy=False): cache = SimpleCache() exists_nonce = create_cache_exists_nonce_func(cache) else: - exists_nonce = create_db_exists_nonce_func(db.session, TimestampNonce) - query_client = create_query_client_func(db.session, Client) - query_token = create_query_token_func(db.session, TokenCredential) + def exists_nonce(nonce, timestamp, client_id, oauth_token): + q = db.session.query(TimestampNonce.nonce).filter_by( + nonce=nonce, + timestamp=timestamp, + client_id=client_id, + ) + if oauth_token: + q = q.filter_by(oauth_token=oauth_token) + rv = q.first() + if rv: + return True + + tn = TimestampNonce( + nonce=nonce, + timestamp=timestamp, + client_id=client_id, + oauth_token=oauth_token, + ) + db.session.add(tn) + db.session.commit() + return False + + def query_client(client_id): + return Client.query.filter_by(client_id=client_id).first() + + def query_token(client_id, oauth_token): + return TokenCredential.query.filter_by( + client_id=client_id, oauth_token=oauth_token + ).first() if lazy: require_oauth = ResourceProtector() require_oauth.init_app(app, query_client, query_token, exists_nonce) else: - require_oauth = ResourceProtector( - app, query_client, query_token, exists_nonce) + require_oauth = ResourceProtector(app, query_client, query_token, exists_nonce) - @app.route('/user') + @app.route("/user") @require_oauth() def user_profile(): user = current_credential.user return jsonify(id=user.id, username=user.username) - -def create_flask_app(): - app = Flask(__name__) - app.debug = True - app.testing = True - app.secret_key = 'testing' - app.config.update({ - 'OAUTH1_SUPPORTED_SIGNATURE_METHODS': ['PLAINTEXT', 'HMAC-SHA1', 'RSA-SHA1'], - 'SQLALCHEMY_TRACK_MODIFICATIONS': False, - 'SQLALCHEMY_DATABASE_URI': 'sqlite://' - }) - return app - - -class TestCase(unittest.TestCase): - def setUp(self): - app = create_flask_app() - - self._ctx = app.app_context() - self._ctx.push() - - db.init_app(app) - db.create_all() - - self.app = app - self.client = app.test_client() - - def tearDown(self): - db.drop_all() - self._ctx.pop() + @app.route("/user-no-parens") + @require_oauth + def user_profile_no_parens(): + user = current_credential.user + return jsonify(id=user.id, username=user.username) diff --git a/tests/flask/test_oauth1/test_authorize.py b/tests/flask/test_oauth1/test_authorize.py index cef927c2e..faa4b585e 100644 --- a/tests/flask/test_oauth1/test_authorize.py +++ b/tests/flask/test_oauth1/test_authorize.py @@ -1,119 +1,133 @@ +import pytest + from tests.util import decode_response -from .oauth1_server import db, User, Client -from .oauth1_server import ( - TestCase, - create_authorization_server, -) - - -class AuthorizationWithCacheTest(TestCase): - USE_CACHE = True - - def prepare_data(self): - create_authorization_server(self.app, self.USE_CACHE, self.USE_CACHE) - user = User(username='foo') - db.session.add(user) - db.session.commit() - client = Client( - user_id=user.id, - client_id='client', - client_secret='secret', - default_redirect_uri='https://a.b', - ) - db.session.add(client) - db.session.commit() - - def test_invalid_authorization(self): - self.prepare_data() - url = '/oauth/authorize' - - # case 1 - rv = self.client.post(url, data={'user_id': '1'}) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_token', data['error_description']) - - # case 2 - rv = self.client.post(url, data={'user_id': '1', 'oauth_token': 'a'}) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'invalid_token') - - def test_authorize_denied(self): - self.prepare_data() - initiate_url = '/oauth/initiate' - authorize_url = '/oauth/authorize' - - rv = self.client.post(initiate_url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'oob', - 'oauth_signature_method': 'PLAINTEXT', - 'oauth_signature': 'secret&' - }) - data = decode_response(rv.data) - self.assertIn('oauth_token', data) - - rv = self.client.post(authorize_url, data={ - 'oauth_token': data['oauth_token'] - }) - self.assertEqual(rv.status_code, 302) - self.assertIn('access_denied', rv.headers['Location']) - self.assertIn('https://a.b', rv.headers['Location']) - - rv = self.client.post(initiate_url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'https://i.test', - 'oauth_signature_method': 'PLAINTEXT', - 'oauth_signature': 'secret&' - }) - data = decode_response(rv.data) - self.assertIn('oauth_token', data) - - rv = self.client.post(authorize_url, data={ - 'oauth_token': data['oauth_token'] - }) - self.assertEqual(rv.status_code, 302) - self.assertIn('access_denied', rv.headers['Location']) - self.assertIn('https://i.test', rv.headers['Location']) - - def test_authorize_granted(self): - self.prepare_data() - initiate_url = '/oauth/initiate' - authorize_url = '/oauth/authorize' - - rv = self.client.post(initiate_url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'oob', - 'oauth_signature_method': 'PLAINTEXT', - 'oauth_signature': 'secret&' - }) - data = decode_response(rv.data) - self.assertIn('oauth_token', data) - - rv = self.client.post(authorize_url, data={ - 'user_id': '1', - 'oauth_token': data['oauth_token'] - }) - self.assertEqual(rv.status_code, 302) - self.assertIn('oauth_verifier', rv.headers['Location']) - self.assertIn('https://a.b', rv.headers['Location']) - - rv = self.client.post(initiate_url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'https://i.test', - 'oauth_signature_method': 'PLAINTEXT', - 'oauth_signature': 'secret&' - }) - data = decode_response(rv.data) - self.assertIn('oauth_token', data) - - rv = self.client.post(authorize_url, data={ - 'user_id': '1', - 'oauth_token': data['oauth_token'] - }) - self.assertEqual(rv.status_code, 302) - self.assertIn('oauth_verifier', rv.headers['Location']) - self.assertIn('https://i.test', rv.headers['Location']) - - -class AuthorizationNoCacheTest(AuthorizationWithCacheTest): - USE_CACHE = False + +from .oauth1_server import Client +from .oauth1_server import User +from .oauth1_server import create_authorization_server + + +@pytest.fixture(autouse=True) +def user(db): + user = User(username="foo") + db.session.add(user) + db.session.commit() + yield user + db.session.delete(user) + + +@pytest.fixture(autouse=True) +def client(db, user): + client = Client( + user_id=user.id, + client_id="client", + client_secret="secret", + default_redirect_uri="https://client.test", + ) + db.session.add(client) + db.session.commit() + yield client + db.session.delete(client) + + +@pytest.mark.parametrize("use_cache", [True, False]) +def test_invalid_authorization(app, test_client, use_cache): + create_authorization_server(app, use_cache, use_cache) + url = "/oauth/authorize" + + # case 1 + rv = test_client.post(url, data={"user_id": "1"}) + data = decode_response(rv.data) + assert data["error"] == "missing_required_parameter" + assert "oauth_token" in data["error_description"] + + # case 2 + rv = test_client.post(url, data={"user_id": "1", "oauth_token": "a"}) + data = decode_response(rv.data) + assert data["error"] == "invalid_token" + + +@pytest.mark.parametrize("use_cache", [True, False]) +def test_authorize_denied(app, test_client, use_cache): + create_authorization_server(app, use_cache, use_cache) + initiate_url = "/oauth/initiate" + authorize_url = "/oauth/authorize" + + rv = test_client.post( + initiate_url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "oob", + "oauth_signature_method": "PLAINTEXT", + "oauth_signature": "secret&", + }, + ) + data = decode_response(rv.data) + assert "oauth_token" in data + + rv = test_client.post(authorize_url, data={"oauth_token": data["oauth_token"]}) + assert rv.status_code == 302 + assert "access_denied" in rv.headers["Location"] + assert "https://client.test" in rv.headers["Location"] + + rv = test_client.post( + initiate_url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "https://i.test", + "oauth_signature_method": "PLAINTEXT", + "oauth_signature": "secret&", + }, + ) + data = decode_response(rv.data) + assert "oauth_token" in data + + rv = test_client.post(authorize_url, data={"oauth_token": data["oauth_token"]}) + assert rv.status_code == 302 + assert "access_denied" in rv.headers["Location"] + assert "https://i.test" in rv.headers["Location"] + + +@pytest.mark.parametrize("use_cache", [True, False]) +def test_authorize_granted(app, test_client, use_cache): + create_authorization_server(app, use_cache, use_cache) + initiate_url = "/oauth/initiate" + authorize_url = "/oauth/authorize" + + rv = test_client.post( + initiate_url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "oob", + "oauth_signature_method": "PLAINTEXT", + "oauth_signature": "secret&", + }, + ) + data = decode_response(rv.data) + assert "oauth_token" in data + + rv = test_client.post( + authorize_url, data={"user_id": "1", "oauth_token": data["oauth_token"]} + ) + assert rv.status_code == 302 + assert "oauth_verifier" in rv.headers["Location"] + assert "https://client.test" in rv.headers["Location"] + + rv = test_client.post( + initiate_url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "https://i.test", + "oauth_signature_method": "PLAINTEXT", + "oauth_signature": "secret&", + }, + ) + data = decode_response(rv.data) + assert "oauth_token" in data + + rv = test_client.post( + authorize_url, data={"user_id": "1", "oauth_token": data["oauth_token"]} + ) + assert rv.status_code == 302 + assert "oauth_verifier" in rv.headers["Location"] + assert "https://i.test" in rv.headers["Location"] diff --git a/tests/flask/test_oauth1/test_resource_protector.py b/tests/flask/test_oauth1/test_resource_protector.py index 87c0e5c42..f85547d5a 100644 --- a/tests/flask/test_oauth1/test_resource_protector.py +++ b/tests/flask/test_oauth1/test_resource_protector.py @@ -1,172 +1,192 @@ import time + +import pytest from flask import json -from authlib.oauth1.rfc5849 import signature + from authlib.common.urls import add_params_to_uri +from authlib.oauth1.rfc5849 import signature from tests.util import read_file_path -from .oauth1_server import db, User, Client, TokenCredential -from .oauth1_server import ( - TestCase, - create_resource_server, -) - - -class ResourceCacheTest(TestCase): - USE_CACHE = True - - def prepare_data(self): - create_resource_server(self.app, self.USE_CACHE, self.USE_CACHE) - user = User(username='foo') - db.session.add(user) - db.session.commit() - - client = Client( - user_id=user.id, - client_id='client', - client_secret='secret', - default_redirect_uri='https://a.b', - ) - db.session.add(client) - db.session.commit() - - tok = TokenCredential( - user_id=user.id, - client_id=client.client_id, - oauth_token='valid-token', - oauth_token_secret='valid-token-secret' - ) - db.session.add(tok) - db.session.commit() - - def test_invalid_request_parameters(self): - self.prepare_data() - url = '/user' - - # case 1 - rv = self.client.get(url) - data = json.loads(rv.data) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_consumer_key', data['error_description']) - - # case 2 - rv = self.client.get( - add_params_to_uri(url, {'oauth_consumer_key': 'a'})) - data = json.loads(rv.data) - self.assertEqual(data['error'], 'invalid_client') - - # case 3 - rv = self.client.get( - add_params_to_uri(url, {'oauth_consumer_key': 'client'})) - data = json.loads(rv.data) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_token', data['error_description']) - - # case 4 - rv = self.client.get( - add_params_to_uri(url, { - 'oauth_consumer_key': 'client', - 'oauth_token': 'a' - }) - ) - data = json.loads(rv.data) - self.assertEqual(data['error'], 'invalid_token') - - # case 5 - rv = self.client.get( - add_params_to_uri(url, { - 'oauth_consumer_key': 'client', - 'oauth_token': 'valid-token' - }) - ) - data = json.loads(rv.data) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_timestamp', data['error_description']) - - def test_plaintext_signature(self): - self.prepare_data() - url = '/user' - - # case 1: success - auth_header = ( - 'OAuth oauth_consumer_key="client",' - 'oauth_signature_method="PLAINTEXT",' - 'oauth_token="valid-token",' - 'oauth_signature="secret&valid-token-secret"' - ) - headers = {'Authorization': auth_header} - rv = self.client.get(url, headers=headers) - data = json.loads(rv.data) - self.assertIn('username', data) - - # case 2: invalid signature - auth_header = auth_header.replace('valid-token-secret', 'invalid') - headers = {'Authorization': auth_header} - rv = self.client.get(url, headers=headers) - data = json.loads(rv.data) - self.assertEqual(data['error'], 'invalid_signature') - - def test_hmac_sha1_signature(self): - self.prepare_data() - url = '/user' - - params = [ - ('oauth_consumer_key', 'client'), - ('oauth_token', 'valid-token'), - ('oauth_signature_method', 'HMAC-SHA1'), - ('oauth_timestamp', str(int(time.time()))), - ('oauth_nonce', 'hmac-sha1-nonce'), - ] - base_string = signature.construct_base_string( - 'GET', 'http://localhost/user', params - ) - sig = signature.hmac_sha1_signature( - base_string, 'secret', 'valid-token-secret') - params.append(('oauth_signature', sig)) - auth_param = ','.join(['{}="{}"'.format(k, v) for k, v in params]) - auth_header = 'OAuth ' + auth_param - headers = {'Authorization': auth_header} - - # case 1: success - rv = self.client.get(url, headers=headers) - data = json.loads(rv.data) - self.assertIn('username', data) - - # case 2: exists nonce - rv = self.client.get(url, headers=headers) - data = json.loads(rv.data) - self.assertEqual(data['error'], 'invalid_nonce') - - def test_rsa_sha1_signature(self): - self.prepare_data() - url = '/user' - - params = [ - ('oauth_consumer_key', 'client'), - ('oauth_token', 'valid-token'), - ('oauth_signature_method', 'RSA-SHA1'), - ('oauth_timestamp', str(int(time.time()))), - ('oauth_nonce', 'rsa-sha1-nonce'), - ] - base_string = signature.construct_base_string( - 'GET', 'http://localhost/user', params + +from .oauth1_server import Client +from .oauth1_server import TokenCredential +from .oauth1_server import User +from .oauth1_server import create_resource_server + + +@pytest.fixture(autouse=True) +def user(db): + user = User(username="foo") + db.session.add(user) + db.session.commit() + yield user + db.session.delete(user) + + +@pytest.fixture(autouse=True) +def client(db, user): + client = Client( + user_id=user.id, + client_id="client", + client_secret="secret", + default_redirect_uri="https://client.test", + ) + db.session.add(client) + db.session.commit() + yield client + db.session.delete(client) + + +@pytest.fixture(autouse=True) +def token(db, user, client): + tok = TokenCredential( + user_id=user.id, + client_id=client.client_id, + oauth_token="valid-token", + oauth_token_secret="valid-token-secret", + ) + db.session.add(tok) + db.session.commit() + + +@pytest.mark.parametrize("use_cache", [True, False]) +def test_invalid_request_parameters(app, test_client, use_cache): + create_resource_server(app, use_cache, use_cache) + url = "/user" + + # case 1 + rv = test_client.get(url) + data = json.loads(rv.data) + assert data["error"] == "missing_required_parameter" + assert "oauth_consumer_key" in data["error_description"] + + # case 2 + rv = test_client.get(add_params_to_uri(url, {"oauth_consumer_key": "a"})) + data = json.loads(rv.data) + assert data["error"] == "invalid_client" + + # case 3 + rv = test_client.get(add_params_to_uri(url, {"oauth_consumer_key": "client"})) + data = json.loads(rv.data) + assert data["error"] == "missing_required_parameter" + assert "oauth_token" in data["error_description"] + + # case 4 + rv = test_client.get( + add_params_to_uri(url, {"oauth_consumer_key": "client", "oauth_token": "a"}) + ) + data = json.loads(rv.data) + assert data["error"] == "invalid_token" + + # case 5 + rv = test_client.get( + add_params_to_uri( + url, {"oauth_consumer_key": "client", "oauth_token": "valid-token"} ) - sig = signature.rsa_sha1_signature( - base_string, read_file_path('rsa_private.pem')) - params.append(('oauth_signature', sig)) - auth_param = ','.join(['{}="{}"'.format(k, v) for k, v in params]) - auth_header = 'OAuth ' + auth_param - headers = {'Authorization': auth_header} - rv = self.client.get(url, headers=headers) - data = json.loads(rv.data) - self.assertIn('username', data) - - # case: invalid signature - auth_param = auth_param.replace('rsa-sha1-nonce', 'alt-sha1-nonce') - auth_header = 'OAuth ' + auth_param - headers = {'Authorization': auth_header} - rv = self.client.get(url, headers=headers) - data = json.loads(rv.data) - self.assertEqual(data['error'], 'invalid_signature') - - -class ResourceDBTest(ResourceCacheTest): - USE_CACHE = False + ) + data = json.loads(rv.data) + assert data["error"] == "missing_required_parameter" + assert "oauth_timestamp" in data["error_description"] + + +@pytest.mark.parametrize("use_cache", [True, False]) +def test_plaintext_signature(app, test_client, use_cache): + create_resource_server(app, use_cache, use_cache) + url = "/user" + + # case 1: success + auth_header = ( + 'OAuth oauth_consumer_key="client",' + 'oauth_signature_method="PLAINTEXT",' + 'oauth_token="valid-token",' + 'oauth_signature="secret&valid-token-secret"' + ) + headers = {"Authorization": auth_header} + rv = test_client.get(url, headers=headers) + data = json.loads(rv.data) + assert "username" in data + + # case 2: invalid signature + auth_header = auth_header.replace("valid-token-secret", "invalid") + headers = {"Authorization": auth_header} + rv = test_client.get(url, headers=headers) + data = json.loads(rv.data) + assert data["error"] == "invalid_signature" + + +@pytest.mark.parametrize("use_cache", [True, False]) +def test_hmac_sha1_signature(app, test_client, use_cache): + create_resource_server(app, use_cache, use_cache) + url = "/user" + + params = [ + ("oauth_consumer_key", "client"), + ("oauth_token", "valid-token"), + ("oauth_signature_method", "HMAC-SHA1"), + ("oauth_timestamp", str(int(time.time()))), + ("oauth_nonce", "hmac-sha1-nonce"), + ] + base_string = signature.construct_base_string( + "GET", "http://localhost/user", params + ) + sig = signature.hmac_sha1_signature(base_string, "secret", "valid-token-secret") + params.append(("oauth_signature", sig)) + auth_param = ",".join([f'{k}="{v}"' for k, v in params]) + auth_header = "OAuth " + auth_param + headers = {"Authorization": auth_header} + + # case 1: success + rv = test_client.get(url, headers=headers) + data = json.loads(rv.data) + assert "username" in data + + # case 2: exists nonce + rv = test_client.get(url, headers=headers) + data = json.loads(rv.data) + assert data["error"] == "invalid_nonce" + + +@pytest.mark.parametrize("use_cache", [True, False]) +def test_rsa_sha1_signature(app, test_client, use_cache): + create_resource_server(app, use_cache, use_cache) + url = "/user" + + params = [ + ("oauth_consumer_key", "client"), + ("oauth_token", "valid-token"), + ("oauth_signature_method", "RSA-SHA1"), + ("oauth_timestamp", str(int(time.time()))), + ("oauth_nonce", "rsa-sha1-nonce"), + ] + base_string = signature.construct_base_string( + "GET", "http://localhost/user", params + ) + sig = signature.rsa_sha1_signature(base_string, read_file_path("rsa_private.pem")) + params.append(("oauth_signature", sig)) + auth_param = ",".join([f'{k}="{v}"' for k, v in params]) + auth_header = "OAuth " + auth_param + headers = {"Authorization": auth_header} + rv = test_client.get(url, headers=headers) + data = json.loads(rv.data) + assert "username" in data + + # case: invalid signature + auth_param = auth_param.replace("rsa-sha1-nonce", "alt-sha1-nonce") + auth_header = "OAuth " + auth_param + headers = {"Authorization": auth_header} + rv = test_client.get(url, headers=headers) + data = json.loads(rv.data) + assert data["error"] == "invalid_signature" + + +def test_decorator_without_parentheses(app, test_client): + create_resource_server(app) + auth_header = ( + 'OAuth oauth_consumer_key="client",' + 'oauth_signature_method="PLAINTEXT",' + 'oauth_token="valid-token",' + 'oauth_signature="secret&valid-token-secret"' + ) + headers = {"Authorization": auth_header} + rv = test_client.get("/user-no-parens", headers=headers) + data = json.loads(rv.data) + assert "username" in data diff --git a/tests/flask/test_oauth1/test_temporary_credentials.py b/tests/flask/test_oauth1/test_temporary_credentials.py index 888b7fd86..2084178eb 100644 --- a/tests/flask/test_oauth1/test_temporary_credentials.py +++ b/tests/flask/test_oauth1/test_temporary_credentials.py @@ -1,287 +1,334 @@ import time + +import pytest + from authlib.oauth1.rfc5849 import signature -from tests.util import read_file_path, decode_response -from .oauth1_server import db, User, Client -from .oauth1_server import ( - TestCase, - create_authorization_server, -) - - -class TemporaryCredentialsWithCacheTest(TestCase): - USE_CACHE = True - - def prepare_data(self): - self.server = create_authorization_server(self.app, self.USE_CACHE) - user = User(username='foo') - db.session.add(user) - db.session.commit() - client = Client( - user_id=user.id, - client_id='client', - client_secret='secret', - default_redirect_uri='https://a.b', - ) - db.session.add(client) - db.session.commit() - - def test_temporary_credential_parameters_errors(self): - self.prepare_data() - url = '/oauth/initiate' - - rv = self.client.get(url) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'method_not_allowed') - - # case 1 - rv = self.client.post(url) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_consumer_key', data['error_description']) - - # case 2 - rv = self.client.post(url, data={'oauth_consumer_key': 'client'}) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_callback', data['error_description']) - - # case 3 - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'invalid_url' - }) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'invalid_request') - self.assertIn('oauth_callback', data['error_description']) - - # case 4 - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'invalid-client', - 'oauth_callback': 'oob' - }) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'invalid_client') - - def test_validate_timestamp_and_nonce(self): - self.prepare_data() - url = '/oauth/initiate' - - # case 5 - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'oob' - }) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_timestamp', data['error_description']) - - # case 6 - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'oob', - 'oauth_timestamp': str(int(time.time())) - }) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_nonce', data['error_description']) - - # case 7 - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'oob', - 'oauth_timestamp': '123' - }) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'invalid_request') - self.assertIn('oauth_timestamp', data['error_description']) - - # case 8 - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'oob', - 'oauth_timestamp': 'sss' - }) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'invalid_request') - self.assertIn('oauth_timestamp', data['error_description']) - - # case 9 - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'oob', - 'oauth_timestamp': '-1', - 'oauth_signature_method': 'PLAINTEXT' - }) - self.assertEqual(data['error'], 'invalid_request') - self.assertIn('oauth_timestamp', data['error_description']) - - def test_temporary_credential_signatures_errors(self): - self.prepare_data() - url = '/oauth/initiate' - - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'oob', - 'oauth_signature_method': 'PLAINTEXT' - }) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_signature', data['error_description']) - - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'oob', - 'oauth_timestamp': str(int(time.time())), - 'oauth_nonce': 'a' - }) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_signature_method', data['error_description']) - - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_signature_method': 'INVALID', - 'oauth_callback': 'oob', - 'oauth_timestamp': str(int(time.time())), - 'oauth_nonce': 'b', - 'oauth_signature': 'c' - }) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'unsupported_signature_method') - - def test_plaintext_signature(self): - self.prepare_data() - url = '/oauth/initiate' - - # case 1: use payload - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'oob', - 'oauth_signature_method': 'PLAINTEXT', - 'oauth_signature': 'secret&' - }) - data = decode_response(rv.data) - self.assertIn('oauth_token', data) - - # case 2: use header - auth_header = ( - 'OAuth oauth_consumer_key="client",' - 'oauth_signature_method="PLAINTEXT",' - 'oauth_callback="oob",' - 'oauth_signature="secret&"' - ) - headers = {'Authorization': auth_header} - rv = self.client.post(url, headers=headers) - data = decode_response(rv.data) - self.assertIn('oauth_token', data) - - # case 3: invalid signature - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'oob', - 'oauth_signature_method': 'PLAINTEXT', - 'oauth_signature': 'invalid-signature' - }) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'invalid_signature') - - def test_hmac_sha1_signature(self): - self.prepare_data() - url = '/oauth/initiate' - - params = [ - ('oauth_consumer_key', 'client'), - ('oauth_callback', 'oob'), - ('oauth_signature_method', 'HMAC-SHA1'), - ('oauth_timestamp', str(int(time.time()))), - ('oauth_nonce', 'hmac-sha1-nonce'), - ] - base_string = signature.construct_base_string( - 'POST', 'http://localhost/oauth/initiate', params - ) - sig = signature.hmac_sha1_signature(base_string, 'secret', None) - params.append(('oauth_signature', sig)) - auth_param = ','.join(['{}="{}"'.format(k, v) for k, v in params]) - auth_header = 'OAuth ' + auth_param - headers = {'Authorization': auth_header} - - # case 1: success - rv = self.client.post(url, headers=headers) - data = decode_response(rv.data) - self.assertIn('oauth_token', data) - - # case 2: exists nonce - rv = self.client.post(url, headers=headers) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'invalid_nonce') - - def test_rsa_sha1_signature(self): - self.prepare_data() - url = '/oauth/initiate' - - params = [ - ('oauth_consumer_key', 'client'), - ('oauth_callback', 'oob'), - ('oauth_signature_method', 'RSA-SHA1'), - ('oauth_timestamp', str(int(time.time()))), - ('oauth_nonce', 'rsa-sha1-nonce'), - ] - base_string = signature.construct_base_string( - 'POST', 'http://localhost/oauth/initiate', params - ) - sig = signature.rsa_sha1_signature( - base_string, read_file_path('rsa_private.pem')) - params.append(('oauth_signature', sig)) - auth_param = ','.join(['{}="{}"'.format(k, v) for k, v in params]) - auth_header = 'OAuth ' + auth_param - headers = {'Authorization': auth_header} - rv = self.client.post(url, headers=headers) - data = decode_response(rv.data) - self.assertIn('oauth_token', data) - - # case: invalid signature - auth_param = auth_param.replace('rsa-sha1-nonce', 'alt-sha1-nonce') - auth_header = 'OAuth ' + auth_param - headers = {'Authorization': auth_header} - rv = self.client.post(url, headers=headers) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'invalid_signature') - - def test_invalid_signature(self): - self.app.config.update({ - 'OAUTH1_SUPPORTED_SIGNATURE_METHODS': ['INVALID'] - }) - self.prepare_data() - url = '/oauth/initiate' - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'oob', - 'oauth_signature_method': 'PLAINTEXT', - 'oauth_signature': 'secret&' - }) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'unsupported_signature_method') - - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_callback': 'oob', - 'oauth_signature_method': 'INVALID', - 'oauth_timestamp': str(int(time.time())), - 'oauth_nonce': 'invalid-nonce', - 'oauth_signature': 'secret&' - }) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'unsupported_signature_method') - - def test_register_signature_method(self): - self.prepare_data() - - def foo(): - pass - - self.server.register_signature_method('foo', foo) - self.assertEqual(self.server.SIGNATURE_METHODS['foo'], foo) - - -class TemporaryCredentialsNoCacheTest(TemporaryCredentialsWithCacheTest): - USE_CACHE = False +from tests.util import decode_response +from tests.util import read_file_path + +from .oauth1_server import Client +from .oauth1_server import User +from .oauth1_server import create_authorization_server + + +@pytest.fixture(autouse=True) +def user(db): + user = User(username="foo") + db.session.add(user) + db.session.commit() + yield user + db.session.delete(user) + + +@pytest.fixture(autouse=True) +def client(db, user): + client = Client( + user_id=user.id, + client_id="client", + client_secret="secret", + default_redirect_uri="https://client.test", + ) + db.session.add(client) + db.session.commit() + yield db + db.session.delete(client) + + +@pytest.mark.parametrize("use_cache", [True, False]) +def test_temporary_credential_parameters_errors(app, test_client, use_cache): + create_authorization_server(app, use_cache) + url = "/oauth/initiate" + + rv = test_client.get(url) + data = decode_response(rv.data) + assert data["error"] == "method_not_allowed" + + # case 1 + rv = test_client.post(url) + data = decode_response(rv.data) + assert data["error"] == "missing_required_parameter" + assert "oauth_consumer_key" in data["error_description"] + + # case 2 + rv = test_client.post(url, data={"oauth_consumer_key": "client"}) + data = decode_response(rv.data) + assert data["error"] == "missing_required_parameter" + assert "oauth_callback" in data["error_description"] + + # case 3 + rv = test_client.post( + url, data={"oauth_consumer_key": "client", "oauth_callback": "invalid_url"} + ) + data = decode_response(rv.data) + assert data["error"] == "invalid_request" + assert "oauth_callback" in data["error_description"] + + # case 4 + rv = test_client.post( + url, data={"oauth_consumer_key": "invalid-client", "oauth_callback": "oob"} + ) + data = decode_response(rv.data) + assert data["error"] == "invalid_client" + + +@pytest.mark.parametrize("use_cache", [True, False]) +def test_validate_timestamp_and_nonce(app, test_client, use_cache): + create_authorization_server(app, use_cache) + url = "/oauth/initiate" + + # case 5 + rv = test_client.post( + url, data={"oauth_consumer_key": "client", "oauth_callback": "oob"} + ) + data = decode_response(rv.data) + assert data["error"] == "missing_required_parameter" + assert "oauth_timestamp" in data["error_description"] + + # case 6 + rv = test_client.post( + url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "oob", + "oauth_timestamp": str(int(time.time())), + }, + ) + data = decode_response(rv.data) + assert data["error"] == "missing_required_parameter" + assert "oauth_nonce" in data["error_description"] + + # case 7 + rv = test_client.post( + url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "oob", + "oauth_timestamp": "123", + }, + ) + data = decode_response(rv.data) + assert data["error"] == "invalid_request" + assert "oauth_timestamp" in data["error_description"] + + # case 8 + rv = test_client.post( + url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "oob", + "oauth_timestamp": "sss", + }, + ) + data = decode_response(rv.data) + assert data["error"] == "invalid_request" + assert "oauth_timestamp" in data["error_description"] + + # case 9 + rv = test_client.post( + url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "oob", + "oauth_timestamp": "-1", + "oauth_signature_method": "PLAINTEXT", + }, + ) + assert data["error"] == "invalid_request" + assert "oauth_timestamp" in data["error_description"] + + +@pytest.mark.parametrize("use_cache", [True, False]) +def test_temporary_credential_signatures_errors(app, test_client, use_cache): + create_authorization_server(app, use_cache) + url = "/oauth/initiate" + + rv = test_client.post( + url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "oob", + "oauth_signature_method": "PLAINTEXT", + }, + ) + data = decode_response(rv.data) + assert data["error"] == "missing_required_parameter" + assert "oauth_signature" in data["error_description"] + + rv = test_client.post( + url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "oob", + "oauth_timestamp": str(int(time.time())), + "oauth_nonce": "a", + }, + ) + data = decode_response(rv.data) + assert data["error"] == "missing_required_parameter" + assert "oauth_signature_method" in data["error_description"] + + rv = test_client.post( + url, + data={ + "oauth_consumer_key": "client", + "oauth_signature_method": "INVALID", + "oauth_callback": "oob", + "oauth_timestamp": str(int(time.time())), + "oauth_nonce": "b", + "oauth_signature": "c", + }, + ) + data = decode_response(rv.data) + assert data["error"] == "unsupported_signature_method" + + +@pytest.mark.parametrize("use_cache", [True, False]) +def test_plaintext_signature(app, test_client, use_cache): + create_authorization_server(app, use_cache) + url = "/oauth/initiate" + + # case 1: use payload + rv = test_client.post( + url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "oob", + "oauth_signature_method": "PLAINTEXT", + "oauth_signature": "secret&", + }, + ) + data = decode_response(rv.data) + assert "oauth_token" in data + + # case 2: use header + auth_header = ( + 'OAuth oauth_consumer_key="client",' + 'oauth_signature_method="PLAINTEXT",' + 'oauth_callback="oob",' + 'oauth_signature="secret&"' + ) + headers = {"Authorization": auth_header} + rv = test_client.post(url, headers=headers) + data = decode_response(rv.data) + assert "oauth_token" in data + + # case 3: invalid signature + rv = test_client.post( + url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "oob", + "oauth_signature_method": "PLAINTEXT", + "oauth_signature": "invalid-signature", + }, + ) + data = decode_response(rv.data) + assert data["error"] == "invalid_signature" + + +@pytest.mark.parametrize("use_cache", [True, False]) +def test_hmac_sha1_signature(app, test_client, use_cache): + create_authorization_server(app, use_cache) + url = "/oauth/initiate" + + params = [ + ("oauth_consumer_key", "client"), + ("oauth_callback", "oob"), + ("oauth_signature_method", "HMAC-SHA1"), + ("oauth_timestamp", str(int(time.time()))), + ("oauth_nonce", "hmac-sha1-nonce"), + ] + base_string = signature.construct_base_string( + "POST", "http://localhost/oauth/initiate", params + ) + sig = signature.hmac_sha1_signature(base_string, "secret", None) + params.append(("oauth_signature", sig)) + auth_param = ",".join([f'{k}="{v}"' for k, v in params]) + auth_header = "OAuth " + auth_param + headers = {"Authorization": auth_header} + + # case 1: success + rv = test_client.post(url, headers=headers) + data = decode_response(rv.data) + assert "oauth_token" in data + + # case 2: exists nonce + rv = test_client.post(url, headers=headers) + data = decode_response(rv.data) + assert data["error"] == "invalid_nonce" + + +@pytest.mark.parametrize("use_cache", [True, False]) +def test_rsa_sha1_signature(app, test_client, use_cache): + create_authorization_server(app, use_cache) + url = "/oauth/initiate" + + params = [ + ("oauth_consumer_key", "client"), + ("oauth_callback", "oob"), + ("oauth_signature_method", "RSA-SHA1"), + ("oauth_timestamp", str(int(time.time()))), + ("oauth_nonce", "rsa-sha1-nonce"), + ] + base_string = signature.construct_base_string( + "POST", "http://localhost/oauth/initiate", params + ) + sig = signature.rsa_sha1_signature(base_string, read_file_path("rsa_private.pem")) + params.append(("oauth_signature", sig)) + auth_param = ",".join([f'{k}="{v}"' for k, v in params]) + auth_header = "OAuth " + auth_param + headers = {"Authorization": auth_header} + rv = test_client.post(url, headers=headers) + data = decode_response(rv.data) + assert "oauth_token" in data + + # case: invalid signature + auth_param = auth_param.replace("rsa-sha1-nonce", "alt-sha1-nonce") + auth_header = "OAuth " + auth_param + headers = {"Authorization": auth_header} + rv = test_client.post(url, headers=headers) + data = decode_response(rv.data) + assert data["error"] == "invalid_signature" + + +@pytest.mark.parametrize("use_cache", [True, False]) +def test_invalid_signature(app, test_client, use_cache): + app.config.update({"OAUTH1_SUPPORTED_SIGNATURE_METHODS": ["INVALID"]}) + create_authorization_server(app, use_cache) + url = "/oauth/initiate" + rv = test_client.post( + url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "oob", + "oauth_signature_method": "PLAINTEXT", + "oauth_signature": "secret&", + }, + ) + data = decode_response(rv.data) + assert data["error"] == "unsupported_signature_method" + + rv = test_client.post( + url, + data={ + "oauth_consumer_key": "client", + "oauth_callback": "oob", + "oauth_signature_method": "INVALID", + "oauth_timestamp": str(int(time.time())), + "oauth_nonce": "invalid-nonce", + "oauth_signature": "secret&", + }, + ) + data = decode_response(rv.data) + assert data["error"] == "unsupported_signature_method" + + +@pytest.mark.parametrize("use_cache", [True, False]) +def test_register_signature_method(app, test_client, use_cache): + server = create_authorization_server(app, use_cache) + + def foo(): + pass + + server.register_signature_method("foo", foo) + assert server.SIGNATURE_METHODS["foo"] == foo diff --git a/tests/flask/test_oauth1/test_token_credentials.py b/tests/flask/test_oauth1/test_token_credentials.py index 3f86b909f..eae43e89f 100644 --- a/tests/flask/test_oauth1/test_token_credentials.py +++ b/tests/flask/test_oauth1/test_token_credentials.py @@ -1,207 +1,221 @@ import time + +import pytest + from authlib.oauth1.rfc5849 import signature -from tests.util import read_file_path, decode_response -from .oauth1_server import db, User, Client -from .oauth1_server import ( - TestCase, - create_authorization_server, -) - - -class TokenCredentialsTest(TestCase): - USE_CACHE = True - - def prepare_data(self): - self.server = create_authorization_server(self.app, self.USE_CACHE) - user = User(username='foo') - db.session.add(user) - db.session.commit() - client = Client( - user_id=user.id, - client_id='client', - client_secret='secret', - default_redirect_uri='https://a.b', - ) - db.session.add(client) - db.session.commit() - - def prepare_temporary_credential(self): - credential = { - 'oauth_token': 'abc', - 'oauth_token_secret': 'abc-secret', - 'oauth_verifier': 'abc-verifier', - 'user': 1 - } - func = self.server._hooks['create_temporary_credential'] - func(credential, 'client', 'oob') - - def test_invalid_token_request_parameters(self): - self.prepare_data() - url = '/oauth/token' - - # case 1 - rv = self.client.post(url) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_consumer_key', data['error_description']) - - # case 2 - rv = self.client.post(url, data={'oauth_consumer_key': 'a'}) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'invalid_client') - - # case 3 - rv = self.client.post(url, data={'oauth_consumer_key': 'client'}) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_token', data['error_description']) - - # case 4 - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_token': 'a' - }) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'invalid_token') - - def test_invalid_token_and_verifiers(self): - self.prepare_data() - url = '/oauth/token' - hook = self.server._hooks['create_temporary_credential'] - - # case 5 - hook( - {'oauth_token': 'abc', 'oauth_token_secret': 'abc-secret'}, - 'client', 'oob' - ) - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_token': 'abc' - }) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'missing_required_parameter') - self.assertIn('oauth_verifier', data['error_description']) - - # case 6 - hook( - {'oauth_token': 'abc', 'oauth_token_secret': 'abc-secret'}, - 'client', 'oob' - ) - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_token': 'abc', - 'oauth_verifier': 'abc' - }) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'invalid_request') - self.assertIn('oauth_verifier', data['error_description']) - - def test_duplicated_oauth_parameters(self): - self.prepare_data() - url = '/oauth/token?oauth_consumer_key=client' - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_token': 'abc', - 'oauth_verifier': 'abc' - }) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'duplicated_oauth_protocol_parameter') - - def test_plaintext_signature(self): - self.prepare_data() - url = '/oauth/token' - - # case 1: success - self.prepare_temporary_credential() - auth_header = ( - 'OAuth oauth_consumer_key="client",' - 'oauth_signature_method="PLAINTEXT",' - 'oauth_token="abc",' - 'oauth_verifier="abc-verifier",' - 'oauth_signature="secret&abc-secret"' - ) - headers = {'Authorization': auth_header} - rv = self.client.post(url, headers=headers) - data = decode_response(rv.data) - self.assertIn('oauth_token', data) - - # case 2: invalid signature - self.prepare_temporary_credential() - rv = self.client.post(url, data={ - 'oauth_consumer_key': 'client', - 'oauth_signature_method': 'PLAINTEXT', - 'oauth_token': 'abc', - 'oauth_verifier': 'abc-verifier', - 'oauth_signature': 'invalid-signature' - }) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'invalid_signature') - - def test_hmac_sha1_signature(self): - self.prepare_data() - url = '/oauth/token' - - params = [ - ('oauth_consumer_key', 'client'), - ('oauth_token', 'abc'), - ('oauth_verifier', 'abc-verifier'), - ('oauth_signature_method', 'HMAC-SHA1'), - ('oauth_timestamp', str(int(time.time()))), - ('oauth_nonce', 'hmac-sha1-nonce'), - ] - base_string = signature.construct_base_string( - 'POST', 'http://localhost/oauth/token', params - ) - sig = signature.hmac_sha1_signature( - base_string, 'secret', 'abc-secret') - params.append(('oauth_signature', sig)) - auth_param = ','.join(['{}="{}"'.format(k, v) for k, v in params]) - auth_header = 'OAuth ' + auth_param - headers = {'Authorization': auth_header} - - # case 1: success - self.prepare_temporary_credential() - rv = self.client.post(url, headers=headers) - data = decode_response(rv.data) - self.assertIn('oauth_token', data) - - # case 2: exists nonce - self.prepare_temporary_credential() - rv = self.client.post(url, headers=headers) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'invalid_nonce') - - def test_rsa_sha1_signature(self): - self.prepare_data() - url = '/oauth/token' - - self.prepare_temporary_credential() - params = [ - ('oauth_consumer_key', 'client'), - ('oauth_token', 'abc'), - ('oauth_verifier', 'abc-verifier'), - ('oauth_signature_method', 'RSA-SHA1'), - ('oauth_timestamp', str(int(time.time()))), - ('oauth_nonce', 'rsa-sha1-nonce'), - ] - base_string = signature.construct_base_string( - 'POST', 'http://localhost/oauth/token', params - ) - sig = signature.rsa_sha1_signature( - base_string, read_file_path('rsa_private.pem')) - params.append(('oauth_signature', sig)) - auth_param = ','.join(['{}="{}"'.format(k, v) for k, v in params]) - auth_header = 'OAuth ' + auth_param - headers = {'Authorization': auth_header} - rv = self.client.post(url, headers=headers) - data = decode_response(rv.data) - self.assertIn('oauth_token', data) - - # case: invalid signature - self.prepare_temporary_credential() - auth_param = auth_param.replace('rsa-sha1-nonce', 'alt-sha1-nonce') - auth_header = 'OAuth ' + auth_param - headers = {'Authorization': auth_header} - rv = self.client.post(url, headers=headers) - data = decode_response(rv.data) - self.assertEqual(data['error'], 'invalid_signature') +from tests.util import decode_response +from tests.util import read_file_path + +from .oauth1_server import Client +from .oauth1_server import User +from .oauth1_server import create_authorization_server + + +@pytest.fixture(autouse=True) +def user(db): + user = User(username="foo") + db.session.add(user) + db.session.commit() + yield user + db.session.delete(user) + + +@pytest.fixture(autouse=True) +def client(db, user): + client = Client( + user_id=user.id, + client_id="client", + client_secret="secret", + default_redirect_uri="https://client.test", + ) + db.session.add(client) + db.session.commit() + yield client + db.session.delete(client) + + +def prepare_temporary_credential(server): + credential = { + "oauth_token": "abc", + "oauth_token_secret": "abc-secret", + "oauth_verifier": "abc-verifier", + "user": 1, + } + func = server._hooks["create_temporary_credential"] + func(credential, "client", "oob") + + +def test_invalid_token_request_parameters(app, test_client): + create_authorization_server(app, use_cache=True) + url = "/oauth/token" + + # case 1 + rv = test_client.post(url) + data = decode_response(rv.data) + assert data["error"] == "missing_required_parameter" + assert "oauth_consumer_key" in data["error_description"] + + # case 2 + rv = test_client.post(url, data={"oauth_consumer_key": "a"}) + data = decode_response(rv.data) + assert data["error"] == "invalid_client" + + # case 3 + rv = test_client.post(url, data={"oauth_consumer_key": "client"}) + data = decode_response(rv.data) + assert data["error"] == "missing_required_parameter" + assert "oauth_token" in data["error_description"] + + # case 4 + rv = test_client.post( + url, data={"oauth_consumer_key": "client", "oauth_token": "a"} + ) + data = decode_response(rv.data) + assert data["error"] == "invalid_token" + + +def test_invalid_token_and_verifiers(app, test_client): + server = create_authorization_server(app, use_cache=True) + url = "/oauth/token" + hook = server._hooks["create_temporary_credential"] + + # case 5 + hook({"oauth_token": "abc", "oauth_token_secret": "abc-secret"}, "client", "oob") + rv = test_client.post( + url, data={"oauth_consumer_key": "client", "oauth_token": "abc"} + ) + data = decode_response(rv.data) + assert data["error"] == "missing_required_parameter" + assert "oauth_verifier" in data["error_description"] + + # case 6 + hook({"oauth_token": "abc", "oauth_token_secret": "abc-secret"}, "client", "oob") + rv = test_client.post( + url, + data={ + "oauth_consumer_key": "client", + "oauth_token": "abc", + "oauth_verifier": "abc", + }, + ) + data = decode_response(rv.data) + assert data["error"] == "invalid_request" + assert "oauth_verifier" in data["error_description"] + + +def test_duplicated_oauth_parameters(app, test_client): + create_authorization_server(app, use_cache=True) + url = "/oauth/token?oauth_consumer_key=client" + rv = test_client.post( + url, + data={ + "oauth_consumer_key": "client", + "oauth_token": "abc", + "oauth_verifier": "abc", + }, + ) + data = decode_response(rv.data) + assert data["error"] == "duplicated_oauth_protocol_parameter" + + +def test_plaintext_signature(app, test_client): + server = create_authorization_server(app, use_cache=True) + url = "/oauth/token" + + # case 1: success + prepare_temporary_credential(server) + auth_header = ( + 'OAuth oauth_consumer_key="client",' + 'oauth_signature_method="PLAINTEXT",' + 'oauth_token="abc",' + 'oauth_verifier="abc-verifier",' + 'oauth_signature="secret&abc-secret"' + ) + headers = {"Authorization": auth_header} + rv = test_client.post(url, headers=headers) + data = decode_response(rv.data) + assert "oauth_token" in data + + # case 2: invalid signature + prepare_temporary_credential(server) + rv = test_client.post( + url, + data={ + "oauth_consumer_key": "client", + "oauth_signature_method": "PLAINTEXT", + "oauth_token": "abc", + "oauth_verifier": "abc-verifier", + "oauth_signature": "invalid-signature", + }, + ) + data = decode_response(rv.data) + assert data["error"] == "invalid_signature" + + +def test_hmac_sha1_signature(app, test_client): + server = create_authorization_server(app, use_cache=True) + url = "/oauth/token" + + params = [ + ("oauth_consumer_key", "client"), + ("oauth_token", "abc"), + ("oauth_verifier", "abc-verifier"), + ("oauth_signature_method", "HMAC-SHA1"), + ("oauth_timestamp", str(int(time.time()))), + ("oauth_nonce", "hmac-sha1-nonce"), + ] + base_string = signature.construct_base_string( + "POST", "http://localhost/oauth/token", params + ) + sig = signature.hmac_sha1_signature(base_string, "secret", "abc-secret") + params.append(("oauth_signature", sig)) + auth_param = ",".join([f'{k}="{v}"' for k, v in params]) + auth_header = "OAuth " + auth_param + headers = {"Authorization": auth_header} + + # case 1: success + prepare_temporary_credential(server) + rv = test_client.post(url, headers=headers) + data = decode_response(rv.data) + assert "oauth_token" in data + + # case 2: exists nonce + prepare_temporary_credential(server) + rv = test_client.post(url, headers=headers) + data = decode_response(rv.data) + assert data["error"] == "invalid_nonce" + + +def test_rsa_sha1_signature(app, test_client): + server = create_authorization_server(app, use_cache=True) + url = "/oauth/token" + + prepare_temporary_credential(server) + params = [ + ("oauth_consumer_key", "client"), + ("oauth_token", "abc"), + ("oauth_verifier", "abc-verifier"), + ("oauth_signature_method", "RSA-SHA1"), + ("oauth_timestamp", str(int(time.time()))), + ("oauth_nonce", "rsa-sha1-nonce"), + ] + base_string = signature.construct_base_string( + "POST", "http://localhost/oauth/token", params + ) + sig = signature.rsa_sha1_signature(base_string, read_file_path("rsa_private.pem")) + params.append(("oauth_signature", sig)) + auth_param = ",".join([f'{k}="{v}"' for k, v in params]) + auth_header = "OAuth " + auth_param + headers = {"Authorization": auth_header} + rv = test_client.post(url, headers=headers) + data = decode_response(rv.data) + assert "oauth_token" in data + + # case: invalid signature + prepare_temporary_credential(server) + auth_param = auth_param.replace("rsa-sha1-nonce", "alt-sha1-nonce") + auth_header = "OAuth " + auth_param + headers = {"Authorization": auth_header} + rv = test_client.post(url, headers=headers) + data = decode_response(rv.data) + assert data["error"] == "invalid_signature" diff --git a/tests/flask/test_oauth2/conftest.py b/tests/flask/test_oauth2/conftest.py new file mode 100644 index 000000000..2ad628b08 --- /dev/null +++ b/tests/flask/test_oauth2/conftest.py @@ -0,0 +1,103 @@ +import os + +import pytest +from flask import Flask + +from tests.flask.test_oauth2.oauth2_server import create_authorization_server + +from .models import Client +from .models import Token +from .models import User + + +@pytest.fixture(autouse=True) +def env(): + os.environ["AUTHLIB_INSECURE_TRANSPORT"] = "true" + yield + del os.environ["AUTHLIB_INSECURE_TRANSPORT"] + + +@pytest.fixture +def app(): + app = Flask(__name__) + app.debug = True + app.testing = True + app.secret_key = "testing" + app.config.update( + { + "SQLALCHEMY_TRACK_MODIFICATIONS": False, + "SQLALCHEMY_DATABASE_URI": "sqlite://", + "OAUTH2_ERROR_URIS": [ + ("invalid_client", "https://client.test/error#invalid_client") + ], + } + ) + with app.app_context(): + yield app + + +@pytest.fixture +def db(app): + from .models import db + + db.init_app(app) + db.create_all() + yield db + db.drop_all() + + +@pytest.fixture +def test_client(app): + return app.test_client() + + +@pytest.fixture(autouse=True) +def user(db): + user = User(username="foo") + db.session.add(user) + db.session.commit() + yield user + db.session.delete(user) + + +@pytest.fixture +def client(db, user): + client = Client( + user_id=user.id, + client_id="client-id", + client_secret="client-secret", + ) + client.set_client_metadata( + { + "redirect_uris": ["https://client.test/authorized"], + "scope": "profile", + "grant_types": ["authorization_code"], + "response_types": ["code"], + } + ) + db.session.add(client) + db.session.commit() + yield client + db.session.delete(client) + + +@pytest.fixture +def server(app): + return create_authorization_server(app) + + +@pytest.fixture +def token(db): + token = Token( + user_id=1, + client_id="client-id", + token_type="bearer", + access_token="a1", + refresh_token="r1", + scope="profile", + expires_in=3600, + ) + db.session.add(token) + db.session.commit() + yield token + db.session.delete(token) diff --git a/tests/flask/test_oauth2/models.py b/tests/flask/test_oauth2/models.py index b04f24cbd..9ebe68baf 100644 --- a/tests/flask/test_oauth2/models.py +++ b/tests/flask/test_oauth2/models.py @@ -1,11 +1,10 @@ -import time from flask_sqlalchemy import SQLAlchemy -from authlib.integrations.sqla_oauth2 import ( - OAuth2ClientMixin, - OAuth2TokenMixin, - OAuth2AuthorizationCodeMixin, -) + +from authlib.integrations.sqla_oauth2 import OAuth2AuthorizationCodeMixin +from authlib.integrations.sqla_oauth2 import OAuth2ClientMixin +from authlib.integrations.sqla_oauth2 import OAuth2TokenMixin from authlib.oidc.core import UserInfo + db = SQLAlchemy() @@ -17,19 +16,45 @@ def get_user_id(self): return self.id def check_password(self, password): - return password != 'wrong' - - def generate_user_info(self, scopes): - profile = {'sub': str(self.id), 'name': self.username} + return password != "wrong" + + def generate_user_info(self, scopes=None): + profile = { + "sub": str(self.id), + "name": self.username, + "given_name": "Jane", + "family_name": "Doe", + "middle_name": "Middle", + "nickname": "Jany", + "preferred_username": "j.doe", + "profile": "https://resource.test/janedoe", + "picture": "https://resource.test/janedoe/me.jpg", + "website": "https://resource.test", + "email": "janedoe@example.com", + "email_verified": True, + "gender": "female", + "birthdate": "2000-12-01", + "zoneinfo": "Europe/Paris", + "locale": "fr-FR", + "phone_number": "+1 (425) 555-1212", + "phone_number_verified": False, + "address": { + "formatted": "742 Evergreen Terrace, Springfield", + "street_address": "742 Evergreen Terrace", + "locality": "Springfield", + "region": "Unknown", + "postal_code": "1245", + "country": "USA", + }, + "updated_at": 1745315119, + } return UserInfo(profile) class Client(db.Model, OAuth2ClientMixin): id = db.Column(db.Integer, primary_key=True) - user_id = db.Column( - db.Integer, db.ForeignKey('user.id', ondelete='CASCADE') - ) - user = db.relationship('User') + user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE")) + user = db.relationship("User") class AuthorizationCode(db.Model, OAuth2AuthorizationCodeMixin): @@ -38,25 +63,29 @@ class AuthorizationCode(db.Model, OAuth2AuthorizationCodeMixin): @property def user(self): - return User.query.get(self.user_id) + return db.session.get(User, self.user_id) class Token(db.Model, OAuth2TokenMixin): id = db.Column(db.Integer, primary_key=True) - user_id = db.Column( - db.Integer, db.ForeignKey('user.id', ondelete='CASCADE') - ) - user = db.relationship('User') + user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE")) + user = db.relationship("User") + + def is_refresh_token_active(self): + return not self.refresh_token_revoked_at + + def get_client(self): + return db.session.query(Client).filter_by(client_id=self.client_id).one() - def is_refresh_token_expired(self): - expired_at = self.issued_at + self.expires_in * 2 - return expired_at < time.time() + def get_user(self): + return self.user -class CodeGrantMixin(object): +class CodeGrantMixin: def query_authorization_code(self, code, client): item = AuthorizationCode.query.filter_by( - code=code, client_id=client.client_id).first() + code=code, client_id=client.client_id + ).first() if item and not item.is_expired(): return item @@ -65,7 +94,7 @@ def delete_authorization_code(self, authorization_code): db.session.commit() def authenticate_user(self, authorization_code): - return User.query.get(authorization_code.user_id) + return db.session.get(User, authorization_code.user_id) def save_authorization_code(code, request): @@ -73,20 +102,22 @@ def save_authorization_code(code, request): auth_code = AuthorizationCode( code=code, client_id=client.client_id, - redirect_uri=request.redirect_uri, + redirect_uri=request.payload.redirect_uri, scope=request.scope, - nonce=request.data.get('nonce'), + nonce=request.payload.data.get("nonce"), user_id=request.user.id, - code_challenge=request.data.get('code_challenge'), - code_challenge_method = request.data.get('code_challenge_method'), + code_challenge=request.payload.data.get("code_challenge"), + code_challenge_method=request.payload.data.get("code_challenge_method"), + acr="urn:mace:incommon:iap:silver", + amr="pwd otp", ) db.session.add(auth_code) db.session.commit() return auth_code -def exists_nonce(nonce, req): +def exists_nonce(nonce, request): exists = AuthorizationCode.query.filter_by( - client_id=req.client_id, nonce=nonce + client_id=request.payload.client_id, nonce=nonce ).first() return bool(exists) diff --git a/tests/flask/test_oauth2/oauth2_server.py b/tests/flask/test_oauth2/oauth2_server.py index 7b7cdc47b..722287ef1 100644 --- a/tests/flask/test_oauth2/oauth2_server.py +++ b/tests/flask/test_oauth2/oauth2_server.py @@ -1,26 +1,27 @@ -import os import base64 -import unittest -from flask import Flask, request + +from flask import Flask +from flask import request + +from authlib.common.encoding import to_bytes +from authlib.common.encoding import to_unicode from authlib.common.security import generate_token -from authlib.common.encoding import to_bytes, to_unicode -from authlib.common.urls import url_encode -from authlib.integrations.sqla_oauth2 import ( - create_query_client_func, - create_save_token_func, -) from authlib.integrations.flask_oauth2 import AuthorizationServer +from authlib.integrations.sqla_oauth2 import create_query_client_func +from authlib.integrations.sqla_oauth2 import create_save_token_func from authlib.oauth2 import OAuth2Error -from .models import db, User, Client, Token -os.environ['AUTHLIB_INSECURE_TRANSPORT'] = 'true' +from .models import Client +from .models import Token +from .models import User +from .models import db def token_generator(client, grant_type, user=None, scope=None): - token = '{}-{}'.format(client.client_id[0], grant_type) + token = f"{client.client_id[0]}-{grant_type}" if user: - token = '{}.{}'.format(token, user.get_user_id()) - return '{}.{}'.format(token, generate_token(32)) + token = f"{token}.{user.get_user_id()}" + return f"{token}.{generate_token(32)}" def create_authorization_server(app, lazy=False): @@ -33,29 +34,28 @@ def create_authorization_server(app, lazy=False): else: server = AuthorizationServer(app, query_client, save_token) - @app.route('/oauth/authorize', methods=['GET', 'POST']) + @app.route("/oauth/authorize", methods=["GET", "POST"]) def authorize(): - if request.method == 'GET': - user_id = request.args.get('user_id') - if user_id: - end_user = User.query.get(int(user_id)) - else: - end_user = None - try: - grant = server.validate_consent_request(end_user=end_user) - return grant.prompt or 'ok' - except OAuth2Error as error: - return url_encode(error.get_body()) - user_id = request.form.get('user_id') + user_id = request.values.get("user_id") if user_id: - grant_user = User.query.get(int(user_id)) + end_user = db.session.get(User, int(user_id)) else: - grant_user = None - return server.create_authorization_response(grant_user=grant_user) + end_user = None + + try: + grant = server.get_consent_grant(end_user=end_user) + except OAuth2Error as error: + return server.handle_error_response(request, error) + + if request.method == "GET": + return grant.prompt or "ok" + + return server.create_authorization_response(grant=grant, grant_user=end_user) - @app.route('/oauth/token', methods=['GET', 'POST']) + @app.route("/oauth/token", methods=["GET", "POST"]) def issue_token(): return server.create_token_response() + return server @@ -63,35 +63,24 @@ def create_flask_app(): app = Flask(__name__) app.debug = True app.testing = True - app.secret_key = 'testing' - app.config.update({ - 'SQLALCHEMY_TRACK_MODIFICATIONS': False, - 'SQLALCHEMY_DATABASE_URI': 'sqlite://', - 'OAUTH2_ERROR_URIS': [ - ('invalid_client', 'https://a.b/e#invalid_client') - ] - }) + app.secret_key = "testing" + app.config.update( + { + "SQLALCHEMY_TRACK_MODIFICATIONS": False, + "SQLALCHEMY_DATABASE_URI": "sqlite://", + "OAUTH2_ERROR_URIS": [ + ("invalid_client", "https://client.test/error#invalid_client") + ], + } + ) return app -class TestCase(unittest.TestCase): - def setUp(self): - app = create_flask_app() - - self._ctx = app.app_context() - self._ctx.push() - - db.init_app(app) - db.create_all() - - self.app = app - self.client = app.test_client() +def create_basic_header(username, password): + text = f"{username}:{password}" + auth = to_unicode(base64.b64encode(to_bytes(text))) + return {"Authorization": "Basic " + auth} - def tearDown(self): - db.drop_all() - self._ctx.pop() - def create_basic_header(self, username, password): - text = '{}:{}'.format(username, password) - auth = to_unicode(base64.b64encode(to_bytes(text))) - return {'Authorization': 'Basic ' + auth} +def create_bearer_header(token): + return {"Authorization": "Bearer " + token} diff --git a/tests/flask/test_oauth2/rfc9068/__init__.py b/tests/flask/test_oauth2/rfc9068/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/flask/test_oauth2/rfc9068/test_resource_server.py b/tests/flask/test_oauth2/rfc9068/test_resource_server.py new file mode 100644 index 000000000..0e665df5e --- /dev/null +++ b/tests/flask/test_oauth2/rfc9068/test_resource_server.py @@ -0,0 +1,367 @@ +import time + +import pytest +from flask import json +from flask import jsonify +from joserfc import jwt +from joserfc.jwk import KeySet + +from authlib.common.security import generate_token +from authlib.integrations.flask_oauth2 import ResourceProtector +from authlib.integrations.flask_oauth2 import current_token +from authlib.oauth2.rfc9068 import JWTBearerTokenValidator +from tests.util import read_file_path + +from ..models import Token +from ..models import User +from ..models import db + +issuer = "https://provider.test/" +resource_server = "resource-server-id" + + +@pytest.fixture(autouse=True) +def token_validator(jwks): + class MyJWTBearerTokenValidator(JWTBearerTokenValidator): + def get_jwks(self): + return jwks + + validator = MyJWTBearerTokenValidator( + issuer=issuer, resource_server=resource_server + ) + return validator + + +@pytest.fixture(autouse=True) +def resource_protector(app, token_validator): + require_oauth = ResourceProtector() + require_oauth.register_token_validator(token_validator) + + @app.route("/protected") + @require_oauth() + def protected(): + user = db.session.get(User, current_token["sub"]) + return jsonify( + id=user.id, + username=user.username, + token=current_token._get_current_object(), + ) + + @app.route("/protected-by-scope") + @require_oauth("profile") + def protected_by_scope(): + user = db.session.get(User, current_token["sub"]) + return jsonify( + id=user.id, + username=user.username, + token=current_token._get_current_object(), + ) + + @app.route("/protected-by-groups") + @require_oauth(groups=["admins"]) + def protected_by_groups(): + user = db.session.get(User, current_token["sub"]) + return jsonify( + id=user.id, + username=user.username, + token=current_token._get_current_object(), + ) + + @app.route("/protected-by-roles") + @require_oauth(roles=["student"]) + def protected_by_roles(): + user = db.session.get(User, current_token["sub"]) + return jsonify( + id=user.id, + username=user.username, + token=current_token._get_current_object(), + ) + + @app.route("/protected-by-entitlements") + @require_oauth(entitlements=["captain"]) + def protected_by_entitlements(): + user = db.session.get(User, current_token["sub"]) + return jsonify( + id=user.id, + username=user.username, + token=current_token._get_current_object(), + ) + + return require_oauth + + +@pytest.fixture +def jwks(): + return KeySet.import_key_set(read_file_path("jwks_private.json")) + + +@pytest.fixture(autouse=True) +def user(db): + user = User(username="foo") + db.session.add(user) + db.session.commit() + yield user + db.session.delete(user) + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["https://client.test/authorized"], + "response_types": ["code"], + "token_endpoint_auth_method": "client_secret_post", + "grant_types": ["authorization_code"], + } + ) + db.session.add(client) + db.session.commit() + return client + + +def create_access_token_claims(client, user): + now = int(time.time()) + expires_in = now + 3600 + auth_time = now - 60 + + return { + "iss": issuer, + "exp": expires_in, + "aud": resource_server, + "sub": user.get_user_id(), + "client_id": client.client_id, + "iat": now, + "jti": generate_token(16), + "auth_time": auth_time, + "scope": client.scope, + "groups": ["admins"], + "roles": ["student"], + "entitlements": ["captain"], + } + + +@pytest.fixture(autouse=True) +def claims(client, user): + return create_access_token_claims(client, user) + + +def create_access_token(claims, jwks, alg="RS256", typ="at+jwt"): + return jwt.encode( + {"alg": alg, "typ": typ}, + claims, + key=jwks, + ) + + +@pytest.fixture +def access_token(claims, jwks): + return create_access_token(claims, jwks) + + +@pytest.fixture +def token(access_token, user): + token = Token( + user_id=user.user_id, + client_id="resource-server", + token_type="bearer", + access_token=access_token, + scope="profile", + expires_in=3600, + ) + db.session.add(token) + db.session.commit() + yield token + db.session.delete(token) + + +def test_access_resource(test_client, access_token): + headers = {"Authorization": f"Bearer {access_token}"} + + rv = test_client.get("/protected", headers=headers) + resp = json.loads(rv.data) + assert resp["username"] == "foo" + + +def test_missing_authorization(test_client): + rv = test_client.get("/protected") + assert rv.status_code == 401 + resp = json.loads(rv.data) + assert resp["error"] == "missing_authorization" + + +def test_unsupported_token_type(test_client): + headers = {"Authorization": "invalid token"} + rv = test_client.get("/protected", headers=headers) + assert rv.status_code == 401 + resp = json.loads(rv.data) + assert resp["error"] == "unsupported_token_type" + + +def test_invalid_token(test_client): + headers = {"Authorization": "Bearer invalid"} + rv = test_client.get("/protected", headers=headers) + assert rv.status_code == 401 + resp = json.loads(rv.data) + assert resp["error"] == "invalid_token" + + +def test_typ(test_client, access_token, claims, jwks): + """The resource server MUST verify that the 'typ' header value is 'at+jwt' or + 'application/at+jwt' and reject tokens carrying any other value. + """ + headers = {"Authorization": f"Bearer {access_token}"} + rv = test_client.get("/protected", headers=headers) + resp = json.loads(rv.data) + assert resp["username"] == "foo" + + access_token = create_access_token(claims, jwks, typ="application/at+jwt") + + headers = {"Authorization": f"Bearer {access_token}"} + rv = test_client.get("/protected", headers=headers) + resp = json.loads(rv.data) + assert resp["username"] == "foo" + + access_token = create_access_token(claims, jwks, typ="invalid") + + headers = {"Authorization": f"Bearer {access_token}"} + rv = test_client.get("/protected", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_token" + + +def test_missing_required_claims(test_client, client, user, jwks): + required_claims = ["iss", "exp", "aud", "sub", "client_id", "iat", "jti"] + for claim in required_claims: + claims = create_access_token_claims(client, user) + del claims[claim] + access_token = create_access_token(claims, jwks) + + headers = {"Authorization": f"Bearer {access_token}"} + rv = test_client.get("/protected", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_token" + + +def test_invalid_iss(test_client, claims, jwks): + """The issuer identifier for the authorization server (which is typically obtained + during discovery) MUST exactly match the value of the 'iss' claim. + """ + claims["iss"] = "invalid-issuer" + access_token = create_access_token(claims, jwks) + + headers = {"Authorization": f"Bearer {access_token}"} + rv = test_client.get("/protected", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_token" + + +def test_invalid_aud(test_client, claims, jwks): + """The resource server MUST validate that the 'aud' claim contains a resource + indicator value corresponding to an identifier the resource server expects for + itself. The JWT access token MUST be rejected if 'aud' does not contain a + resource indicator of the current resource server as a valid audience. + """ + claims["aud"] = "invalid-resource-indicator" + access_token = create_access_token(claims, jwks) + + headers = {"Authorization": f"Bearer {access_token}"} + rv = test_client.get("/protected", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_token" + + +def test_invalid_exp(test_client, claims, jwks): + """The current time MUST be before the time represented by the 'exp' claim. + Implementers MAY provide for some small leeway, usually no more than a few + minutes, to account for clock skew. + """ + claims["exp"] = time.time() - 1 + access_token = create_access_token(claims, jwks) + + headers = {"Authorization": f"Bearer {access_token}"} + rv = test_client.get("/protected", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_token" + + +def test_scope_restriction(test_client, claims, jwks): + """If an authorization request includes a scope parameter, the corresponding + issued JWT access token SHOULD include a 'scope' claim as defined in Section + 4.2 of [RFC8693]. All the individual scope strings in the 'scope' claim MUST + have meaning for the resources indicated in the 'aud' claim. See Section 5 for + more considerations about the relationship between scope strings and resources + indicated by the 'aud' claim. + """ + claims["scope"] = ["invalid-scope"] + access_token = create_access_token(claims, jwks) + + headers = {"Authorization": f"Bearer {access_token}"} + rv = test_client.get("/protected", headers=headers) + resp = json.loads(rv.data) + assert resp["username"] == "foo" + + rv = test_client.get("/protected-by-scope", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "insufficient_scope" + + +def test_entitlements_restriction(test_client, client, user, jwks): + """Many authorization servers embed authorization attributes that go beyond the + delegated scenarios described by [RFC7519] in the access tokens they issue. + Typical examples include resource owner memberships in roles and groups that + are relevant to the resource being accessed, entitlements assigned to the + resource owner for the targeted resource that the authorization server knows + about, and so on. An authorization server wanting to include such attributes + in a JWT access token SHOULD use the 'groups', 'roles', and 'entitlements' + attributes of the 'User' resource schema defined by Section 4.1.2 of + [RFC7643]) as claim types. + """ + for claim in ["groups", "roles", "entitlements"]: + claims = create_access_token_claims(client, user) + claims[claim] = ["invalid"] + access_token = create_access_token(claims, jwks) + + headers = {"Authorization": f"Bearer {access_token}"} + rv = test_client.get("/protected", headers=headers) + resp = json.loads(rv.data) + assert resp["username"] == "foo" + + rv = test_client.get(f"/protected-by-{claim}", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_token" + + +def test_extra_attributes(test_client, claims, jwks): + """Authorization servers MAY return arbitrary attributes not defined in any + existing specification, as long as the corresponding claim names are collision + resistant or the access tokens are meant to be used only within a private + subsystem. Please refer to Sections 4.2 and 4.3 of [RFC7519] for details. + """ + claims["email"] = "user@example.org" + access_token = create_access_token(claims, jwks) + + headers = {"Authorization": f"Bearer {access_token}"} + rv = test_client.get("/protected", headers=headers) + resp = json.loads(rv.data) + assert resp["token"]["email"] == "user@example.org" + + +def test_invalid_auth_time(test_client, claims, jwks): + claims["auth_time"] = "invalid-auth-time" + access_token = create_access_token(claims, jwks) + + headers = {"Authorization": f"Bearer {access_token}"} + rv = test_client.get("/protected", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_token" + + +def test_invalid_amr(test_client, claims, jwks): + claims["amr"] = "invalid-amr" + access_token = create_access_token(claims, jwks) + + headers = {"Authorization": f"Bearer {access_token}"} + rv = test_client.get("/protected", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_token" diff --git a/tests/flask/test_oauth2/rfc9068/test_token_generation.py b/tests/flask/test_oauth2/rfc9068/test_token_generation.py new file mode 100644 index 000000000..ed0f4966e --- /dev/null +++ b/tests/flask/test_oauth2/rfc9068/test_token_generation.py @@ -0,0 +1,230 @@ +import pytest + +from authlib.common.urls import url_decode +from authlib.common.urls import urlparse +from authlib.jose import jwt +from authlib.oauth2.rfc6749.grants import ( + AuthorizationCodeGrant as _AuthorizationCodeGrant, +) +from authlib.oauth2.rfc9068 import JWTBearerTokenGenerator +from tests.util import read_file_path + +from ..models import CodeGrantMixin +from ..models import User +from ..models import save_authorization_code + +issuer = "https://authlib.test/" + + +@pytest.fixture +def user(db): + user = User(username="foo") + db.session.add(user) + db.session.commit() + yield user + db.session.delete(user) + + +@pytest.fixture +def client(client, db): + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["https://client.test/authorized"], + "response_types": ["code"], + "token_endpoint_auth_method": "client_secret_post", + "grant_types": ["authorization_code"], + } + ) + db.session.add(client) + db.session.commit() + return client + + +@pytest.fixture +def jwks(): + return read_file_path("jwks_private.json") + + +@pytest.fixture(autouse=True) +def server(server): + server.register_grant(AuthorizationCodeGrant) + return server + + +@pytest.fixture(autouse=True) +def token_generator(server, jwks): + class MyJWTBearerTokenGenerator(JWTBearerTokenGenerator): + def get_jwks(self): + return jwks + + token_generator = MyJWTBearerTokenGenerator(issuer=issuer) + server.register_token_generator("default", token_generator) + return token_generator + + +class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): + TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"] + + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + +def test_generate_jwt_access_token(test_client, client, user, jwks): + res = test_client.post( + "/oauth/authorize", + data={ + "response_type": client.response_types[0], + "client_id": client.client_id, + "redirect_uri": client.redirect_uris[0], + "scope": client.scope, + "user_id": user.id, + }, + ) + + params = dict(url_decode(urlparse.urlparse(res.location).query)) + code = params["code"] + res = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + "client_id": client.client_id, + "client_secret": client.client_secret, + "scope": " ".join(client.scope), + "redirect_uri": client.redirect_uris[0], + }, + ) + + access_token = res.json["access_token"] + claims = jwt.decode(access_token, jwks) + + assert claims["iss"] == issuer + assert claims["sub"] == user.id + assert claims["scope"] == client.scope + assert claims["client_id"] == client.client_id + + # This specification registers the 'application/at+jwt' media type, which can + # be used to indicate that the content is a JWT access token. JWT access tokens + # MUST include this media type in the 'typ' header parameter to explicitly + # declare that the JWT represents an access token complying with this profile. + # Per the definition of 'typ' in Section 4.1.9 of [RFC7515], it is RECOMMENDED + # that the 'application/' prefix be omitted. Therefore, the 'typ' value used + # SHOULD be 'at+jwt'. + + assert claims.header["typ"] == "at+jwt" + + +def test_generate_jwt_access_token_extra_claims( + test_client, token_generator, user, client, jwks +): + """Authorization servers MAY return arbitrary attributes not defined in any + existing specification, as long as the corresponding claim names are collision + resistant or the access tokens are meant to be used only within a private + subsystem. Please refer to Sections 4.2 and 4.3 of [RFC7519] for details. + """ + + def get_extra_claims(client, grant_type, user, scope): + return {"username": user.username} + + token_generator.get_extra_claims = get_extra_claims + + res = test_client.post( + "/oauth/authorize", + data={ + "response_type": client.response_types[0], + "client_id": client.client_id, + "redirect_uri": client.redirect_uris[0], + "scope": client.scope, + "user_id": user.id, + }, + ) + + params = dict(url_decode(urlparse.urlparse(res.location).query)) + code = params["code"] + res = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + "client_id": client.client_id, + "client_secret": client.client_secret, + "scope": " ".join(client.scope), + "redirect_uri": client.redirect_uris[0], + }, + ) + + access_token = res.json["access_token"] + claims = jwt.decode(access_token, jwks) + assert claims["username"] == user.username + + +@pytest.mark.skip +def test_generate_jwt_access_token_no_user(test_client, client, user, jwks): + res = test_client.post( + "/oauth/authorize", + data={ + "response_type": client.response_types[0], + "client_id": client.client_id, + "redirect_uri": client.redirect_uris[0], + "scope": client.scope, + #'user_id': user.id, + }, + ) + + params = dict(url_decode(urlparse.urlparse(res.location).query)) + code = params["code"] + res = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + "client_id": client.client_id, + "client_secret": client.client_secret, + "scope": " ".join(client.scope), + "redirect_uri": client.redirect_uris[0], + }, + ) + + access_token = res.json["access_token"] + claims = jwt.decode(access_token, jwks) + + assert claims["sub"] == client.client_id + + +def test_optional_fields(test_client, token_generator, user, client, jwks): + token_generator.get_auth_time = lambda *args: 1234 + token_generator.get_amr = lambda *args: "amr" + token_generator.get_acr = lambda *args: "acr" + + res = test_client.post( + "/oauth/authorize", + data={ + "response_type": client.response_types[0], + "client_id": client.client_id, + "redirect_uri": client.redirect_uris[0], + "scope": client.scope, + "user_id": user.id, + }, + ) + + params = dict(url_decode(urlparse.urlparse(res.location).query)) + code = params["code"] + res = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + "client_id": client.client_id, + "client_secret": client.client_secret, + "scope": " ".join(client.scope), + "redirect_uri": client.redirect_uris[0], + }, + ) + + access_token = res.json["access_token"] + claims = jwt.decode(access_token, jwks) + + assert claims["auth_time"] == 1234 + assert claims["amr"] == "amr" + assert claims["acr"] == "acr" diff --git a/tests/flask/test_oauth2/rfc9068/test_token_introspection.py b/tests/flask/test_oauth2/rfc9068/test_token_introspection.py new file mode 100644 index 000000000..cf7a7ef35 --- /dev/null +++ b/tests/flask/test_oauth2/rfc9068/test_token_introspection.py @@ -0,0 +1,255 @@ +import time + +import pytest +from flask import json + +from authlib.common.security import generate_token +from authlib.jose import jwt +from authlib.oauth2.rfc6749.grants import ( + AuthorizationCodeGrant as _AuthorizationCodeGrant, +) +from authlib.oauth2.rfc7662 import IntrospectionEndpoint +from authlib.oauth2.rfc9068 import JWTIntrospectionEndpoint +from tests.util import read_file_path + +from ..models import CodeGrantMixin +from ..models import User +from ..models import db +from ..models import save_authorization_code +from ..oauth2_server import create_basic_header + +issuer = "https://provider.test/" +resource_server = "resource-server-id" + + +@pytest.fixture +def jwks(): + return read_file_path("jwks_private.json") + + +@pytest.fixture(autouse=True) +def server(server): + class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): + TOKEN_ENDPOINT_AUTH_METHODS = [ + "client_secret_basic", + "client_secret_post", + "none", + ] + + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + server.register_grant(AuthorizationCodeGrant) + return server + + +@pytest.fixture(autouse=True) +def introspection_endpoint(server, app, jwks): + class MyJWTIntrospectionEndpoint(JWTIntrospectionEndpoint): + def get_jwks(self): + return jwks + + def check_permission(self, token, client, request): + return client.client_id == "client-id" + + endpoint = MyJWTIntrospectionEndpoint(issuer=issuer) + server.register_endpoint(endpoint) + + @app.route("/oauth/introspect", methods=["POST"]) + def introspect_token(): + return server.create_endpoint_response(MyJWTIntrospectionEndpoint.ENDPOINT_NAME) + + return endpoint + + +@pytest.fixture(autouse=True) +def user(db): + user = User(username="foo") + db.session.add(user) + db.session.commit() + yield user + db.session.delete(user) + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["https://client.test/authorized"], + "response_types": ["code"], + "token_endpoint_auth_method": "client_secret_post", + "grant_types": ["authorization_code"], + } + ) + db.session.add(client) + db.session.commit() + return client + + +def create_access_token_claims(client, user): + now = int(time.time()) + expires_in = now + 3600 + auth_time = now - 60 + + return { + "iss": issuer, + "exp": expires_in, + "aud": [resource_server], + "sub": user.get_user_id(), + "client_id": client.client_id, + "iat": now, + "jti": generate_token(16), + "auth_time": auth_time, + "scope": client.scope, + "groups": ["admins"], + "roles": ["student"], + "entitlements": ["captain"], + } + + +@pytest.fixture +def claims(client, user): + return create_access_token_claims(client, user) + + +def create_access_token(claims, jwks, alg="RS256", typ="at+jwt"): + header = {"alg": alg, "typ": typ} + access_token = jwt.encode( + header, + claims, + key=jwks, + check=False, + ) + return access_token.decode() + + +@pytest.fixture +def access_token(claims, jwks): + return create_access_token(claims, jwks) + + +def test_introspection(test_client, client, user, access_token): + headers = create_basic_header(client.client_id, client.client_secret) + rv = test_client.post( + "/oauth/introspect", data={"token": access_token}, headers=headers + ) + assert rv.status_code == 200 + resp = json.loads(rv.data) + assert resp["active"] + assert resp["client_id"] == client.client_id + assert resp["token_type"] == "Bearer" + assert resp["scope"] == client.scope + assert resp["sub"] == user.id + assert resp["aud"] == [resource_server] + assert resp["iss"] == issuer + + +def test_introspection_username( + test_client, client, user, introspection_endpoint, access_token +): + introspection_endpoint.get_username = lambda user_id: ( + db.session.get(User, user_id).username + ) + + headers = create_basic_header(client.client_id, client.client_secret) + rv = test_client.post( + "/oauth/introspect", data={"token": access_token}, headers=headers + ) + assert rv.status_code == 200 + resp = json.loads(rv.data) + assert resp["active"] + assert resp["username"] == user.username + + +def test_non_access_token_skipped(test_client, client, server): + class MyIntrospectionEndpoint(IntrospectionEndpoint): + def query_token(self, token, token_type_hint): + return None + + server.register_endpoint(MyIntrospectionEndpoint) + headers = create_basic_header(client.client_id, client.client_secret) + rv = test_client.post( + "/oauth/introspect", + data={ + "token": "refresh-token", + "token_type_hint": "refresh_token", + }, + headers=headers, + ) + assert rv.status_code == 200 + resp = json.loads(rv.data) + assert not resp["active"] + + +def test_access_token_non_jwt_skipped(test_client, client, server): + class MyIntrospectionEndpoint(IntrospectionEndpoint): + def query_token(self, token, token_type_hint): + return None + + server.register_endpoint(MyIntrospectionEndpoint) + headers = create_basic_header(client.client_id, client.client_secret) + rv = test_client.post( + "/oauth/introspect", + data={ + "token": "non-jwt-access-token", + }, + headers=headers, + ) + assert rv.status_code == 200 + resp = json.loads(rv.data) + assert not resp["active"] + + +def test_permission_denied(test_client, introspection_endpoint, access_token, client): + introspection_endpoint.check_permission = lambda *args: False + + headers = create_basic_header(client.client_id, client.client_secret) + rv = test_client.post( + "/oauth/introspect", data={"token": access_token}, headers=headers + ) + assert rv.status_code == 200 + resp = json.loads(rv.data) + assert not resp["active"] + + +def test_token_expired(test_client, claims, client, jwks): + claims["exp"] = time.time() - 3600 + access_token = create_access_token(claims, jwks) + headers = create_basic_header(client.client_id, client.client_secret) + rv = test_client.post( + "/oauth/introspect", data={"token": access_token}, headers=headers + ) + assert rv.status_code == 200 + resp = json.loads(rv.data) + assert not resp["active"] + + +def test_introspection_different_issuer(test_client, server, claims, client, jwks): + class MyIntrospectionEndpoint(IntrospectionEndpoint): + def query_token(self, token, token_type_hint): + return None + + server.register_endpoint(MyIntrospectionEndpoint) + + claims["iss"] = "different-issuer" + access_token = create_access_token(claims, jwks) + headers = create_basic_header(client.client_id, client.client_secret) + rv = test_client.post( + "/oauth/introspect", data={"token": access_token}, headers=headers + ) + assert rv.status_code == 200 + resp = json.loads(rv.data) + assert not resp["active"] + + +def test_introspection_invalid_claim(test_client, claims, client, jwks): + claims["exp"] = "invalid" + access_token = create_access_token(claims, jwks) + headers = create_basic_header(client.client_id, client.client_secret) + rv = test_client.post( + "/oauth/introspect", data={"token": access_token}, headers=headers + ) + assert rv.status_code == 401 + resp = json.loads(rv.data) + assert resp["error"] == "invalid_token" diff --git a/tests/flask/test_oauth2/rfc9068/test_token_revocation.py b/tests/flask/test_oauth2/rfc9068/test_token_revocation.py new file mode 100644 index 000000000..a0466781d --- /dev/null +++ b/tests/flask/test_oauth2/rfc9068/test_token_revocation.py @@ -0,0 +1,187 @@ +import time + +import pytest +from flask import json + +from authlib.common.security import generate_token +from authlib.jose import jwt +from authlib.oauth2.rfc6749.grants import ( + AuthorizationCodeGrant as _AuthorizationCodeGrant, +) +from authlib.oauth2.rfc7009 import RevocationEndpoint +from authlib.oauth2.rfc9068 import JWTRevocationEndpoint +from tests.util import read_file_path + +from ..models import CodeGrantMixin +from ..models import User +from ..models import save_authorization_code +from ..oauth2_server import create_basic_header + +issuer = "https://provider.test/" +resource_server = "resource-server-id" + + +@pytest.fixture +def jwks(): + return read_file_path("jwks_private.json") + + +@pytest.fixture(autouse=True) +def server(server): + class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): + TOKEN_ENDPOINT_AUTH_METHODS = [ + "client_secret_basic", + "client_secret_post", + "none", + ] + + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + server.register_grant(AuthorizationCodeGrant) + return server + + +@pytest.fixture(autouse=True) +def revocation_endpoint(app, server, jwks): + class MyJWTRevocationEndpoint(JWTRevocationEndpoint): + def get_jwks(self): + return jwks + + endpoint = MyJWTRevocationEndpoint(issuer=issuer) + server.register_endpoint(endpoint) + + @app.route("/oauth/revoke", methods=["POST"]) + def revoke_token(): + return server.create_endpoint_response(MyJWTRevocationEndpoint.ENDPOINT_NAME) + + return endpoint + + +@pytest.fixture(autouse=True) +def user(db): + user = User(username="foo") + db.session.add(user) + db.session.commit() + yield user + db.session.delete(user) + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["https://client.test/authorized"], + "response_types": ["code"], + "token_endpoint_auth_method": "client_secret_post", + "grant_types": ["authorization_code"], + } + ) + db.session.add(client) + db.session.commit() + return client + + +def create_access_token_claims(client, user): + now = int(time.time()) + expires_in = now + 3600 + auth_time = now - 60 + + return { + "iss": issuer, + "exp": expires_in, + "aud": [resource_server], + "sub": user.get_user_id(), + "client_id": client.client_id, + "iat": now, + "jti": generate_token(16), + "auth_time": auth_time, + "scope": client.scope, + "groups": ["admins"], + "roles": ["student"], + "entitlements": ["captain"], + } + + +@pytest.fixture +def claims(client, user): + return create_access_token_claims(client, user) + + +def create_access_token(claims, jwks, alg="RS256", typ="at+jwt"): + header = {"alg": alg, "typ": typ} + access_token = jwt.encode( + header, + claims, + key=jwks, + check=False, + ) + return access_token.decode() + + +@pytest.fixture +def access_token(claims, jwks): + return create_access_token(claims, jwks) + + +def test_revocation(test_client, client, access_token): + headers = create_basic_header(client.client_id, client.client_secret) + rv = test_client.post( + "/oauth/revoke", data={"token": access_token}, headers=headers + ) + assert rv.status_code == 401 + resp = json.loads(rv.data) + assert resp["error"] == "unsupported_token_type" + + +def test_non_access_token_skipped(test_client, server, client): + class MyRevocationEndpoint(RevocationEndpoint): + def query_token(self, token, token_type_hint): + return None + + server.register_endpoint(MyRevocationEndpoint) + headers = create_basic_header(client.client_id, client.client_secret) + rv = test_client.post( + "/oauth/revoke", + data={ + "token": "refresh-token", + "token_type_hint": "refresh_token", + }, + headers=headers, + ) + assert rv.status_code == 200 + resp = json.loads(rv.data) + assert resp == {} + + +def test_access_token_non_jwt_skipped(test_client, server, client): + class MyRevocationEndpoint(RevocationEndpoint): + def query_token(self, token, token_type_hint): + return None + + server.register_endpoint(MyRevocationEndpoint) + headers = create_basic_header(client.client_id, client.client_secret) + rv = test_client.post( + "/oauth/revoke", + data={ + "token": "non-jwt-access-token", + }, + headers=headers, + ) + assert rv.status_code == 200 + resp = json.loads(rv.data) + assert resp == {} + + +def test_revocation_different_issuer(test_client, claims, jwks, client): + claims["iss"] = "different-issuer" + access_token = create_access_token(claims, jwks) + + headers = create_basic_header(client.client_id, client.client_secret) + rv = test_client.post( + "/oauth/revoke", data={"token": access_token}, headers=headers + ) + assert rv.status_code == 401 + resp = json.loads(rv.data) + assert resp["error"] == "unsupported_token_type" diff --git a/tests/flask/test_oauth2/test_authorization_code_grant.py b/tests/flask/test_oauth2/test_authorization_code_grant.py index 8698c31f3..750bc733b 100644 --- a/tests/flask/test_oauth2/test_authorization_code_grant.py +++ b/tests/flask/test_oauth2/test_authorization_code_grant.py @@ -1,254 +1,432 @@ +import pytest from flask import json -from authlib.common.urls import urlparse, url_decode + +from authlib.common.urls import url_decode +from authlib.common.urls import urlparse from authlib.oauth2.rfc6749.grants import ( AuthorizationCodeGrant as _AuthorizationCodeGrant, ) -from .models import db, User, Client, AuthorizationCode -from .models import CodeGrantMixin, save_authorization_code -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server + +from .models import AuthorizationCode +from .models import CodeGrantMixin +from .models import db +from .models import save_authorization_code +from .oauth2_server import create_basic_header + +authorize_url = "/oauth/authorize?response_type=code&client_id=client-id" + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "profile address", + "token_endpoint_auth_method": "client_secret_basic", + "response_types": ["code"], + "grant_types": ["authorization_code"], + } + ) + db.session.add(client) + db.session.commit() + return client class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): - TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none'] + TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"] def save_authorization_code(self, code, request): return save_authorization_code(code, request) -class AuthorizationCodeTest(TestCase): - LAZY_INIT = False - - def register_grant(self, server): - server.register_grant(AuthorizationCodeGrant) - - def prepare_data( - self, is_confidential=True, - response_type='code', grant_type='authorization_code', - token_endpoint_auth_method='client_secret_basic'): - server = create_authorization_server(self.app, self.LAZY_INIT) - self.register_grant(server) - self.server = server - - user = User(username='foo') - db.session.add(user) - db.session.commit() - - if is_confidential: - client_secret = 'code-secret' - else: - client_secret = '' - client = Client( - user_id=user.id, - client_id='code-client', - client_secret=client_secret, - ) - client.set_client_metadata({ - 'redirect_uris': ['https://a.b'], - 'scope': 'profile address', - 'token_endpoint_auth_method': token_endpoint_auth_method, - 'response_types': [response_type], - 'grant_types': grant_type.splitlines(), - }) - self.authorize_url = ( - '/oauth/authorize?response_type=code' - '&client_id=code-client' - ) - db.session.add(client) - db.session.commit() - - def test_get_authorize(self): - self.prepare_data() - rv = self.client.get(self.authorize_url) - self.assertEqual(rv.data, b'ok') - - def test_invalid_client_id(self): - self.prepare_data() - url = '/oauth/authorize?response_type=code' - rv = self.client.get(url) - self.assertIn(b'invalid_client', rv.data) - - url = '/oauth/authorize?response_type=code&client_id=invalid' - rv = self.client.get(url) - self.assertIn(b'invalid_client', rv.data) - - def test_invalid_authorize(self): - self.prepare_data() - rv = self.client.post(self.authorize_url) - self.assertIn('error=access_denied', rv.location) - - self.server.metadata = {'scopes_supported': ['profile']} - rv = self.client.post(self.authorize_url + '&scope=invalid&state=foo') - self.assertIn('error=invalid_scope', rv.location) - self.assertIn('state=foo', rv.location) - - def test_unauthorized_client(self): - self.prepare_data(True, 'token') - rv = self.client.get(self.authorize_url) - self.assertIn(b'unauthorized_client', rv.data) - - def test_invalid_client(self): - self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': 'invalid', - 'client_id': 'invalid-id', - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - headers = self.create_basic_header('code-client', 'invalid-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': 'invalid', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - self.assertEqual(resp['error_uri'], 'https://a.b/e#invalid_client') - - def test_invalid_code(self): - self.prepare_data() - - headers = self.create_basic_header('code-client', 'code-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': 'invalid', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - code = AuthorizationCode( - code='no-user', - client_id='code-client', - user_id=0 - ) - db.session.add(code) - db.session.commit() - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': 'no-user', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - def test_invalid_redirect_uri(self): - self.prepare_data() - uri = self.authorize_url + '&redirect_uri=https%3A%2F%2Fa.c' - rv = self.client.post(uri, data={'user_id': '1'}) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - uri = self.authorize_url + '&redirect_uri=https%3A%2F%2Fa.b' - rv = self.client.post(uri, data={'user_id': '1'}) - self.assertIn('code=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - code = params['code'] - headers = self.create_basic_header('code-client', 'code-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': code, - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - def test_invalid_grant_type(self): - self.prepare_data( - False, token_endpoint_auth_method='none', - grant_type='invalid' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'client_id': 'code-client', - 'code': 'a', - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'unauthorized_client') - - def test_authorize_token_no_refresh_token(self): - self.app.config.update({'OAUTH2_REFRESH_TOKEN_GENERATOR': True}) - self.prepare_data(False, token_endpoint_auth_method='none') - - rv = self.client.post(self.authorize_url, data={'user_id': '1'}) - self.assertIn('code=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - code = params['code'] - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': code, - 'client_id': 'code-client', - }) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - self.assertNotIn('refresh_token', resp) - - def test_authorize_token_has_refresh_token(self): - # generate refresh token - self.app.config.update({'OAUTH2_REFRESH_TOKEN_GENERATOR': True}) - self.prepare_data(grant_type='authorization_code\nrefresh_token') - url = self.authorize_url + '&state=bar' - rv = self.client.post(url, data={'user_id': '1'}) - self.assertIn('code=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - self.assertEqual(params['state'], 'bar') - - code = params['code'] - headers = self.create_basic_header('code-client', 'code-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': code, - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - self.assertIn('refresh_token', resp) - - def test_client_secret_post(self): - self.app.config.update({'OAUTH2_REFRESH_TOKEN_GENERATOR': True}) - self.prepare_data( - grant_type='authorization_code\nrefresh_token', - token_endpoint_auth_method='client_secret_post', - ) - url = self.authorize_url + '&state=bar' - rv = self.client.post(url, data={'user_id': '1'}) - self.assertIn('code=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - self.assertEqual(params['state'], 'bar') - - code = params['code'] - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'client_id': 'code-client', - 'client_secret': 'code-secret', - 'code': code, - }) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - self.assertIn('refresh_token', resp) - - def test_token_generator(self): - m = 'tests.flask.test_oauth2.oauth2_server:token_generator' - self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) - self.prepare_data(False, token_endpoint_auth_method='none') - - rv = self.client.post(self.authorize_url, data={'user_id': '1'}) - self.assertIn('code=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - code = params['code'] - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': code, - 'client_id': 'code-client', - }) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - self.assertIn('c-authorization_code.1.', resp['access_token']) +@pytest.fixture(autouse=True) +def server(server): + server.register_grant(AuthorizationCodeGrant) + return server + + +def test_get_authorize(test_client): + rv = test_client.get(authorize_url) + assert rv.data == b"ok" + + +def test_invalid_client_id(test_client): + url = "/oauth/authorize?response_type=code" + rv = test_client.get(url) + assert b"invalid_client" in rv.data + + url = "/oauth/authorize?response_type=code&client_id=invalid" + rv = test_client.get(url) + assert b"invalid_client" in rv.data + + +def test_invalid_authorize(test_client, server): + rv = test_client.post(authorize_url) + assert "error=access_denied" in rv.location + + server.scopes_supported = ["profile"] + rv = test_client.post(authorize_url + "&scope=invalid&state=foo") + assert "error=invalid_scope" in rv.location + assert "state=foo" in rv.location + + +def test_unauthorized_client(test_client, client, db): + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "profile address", + "token_endpoint_auth_method": "client_secret_basic", + "response_types": ["token"], + "grant_types": ["authorization_code"], + } + ) + db.session.add(client) + db.session.commit() + + rv = test_client.get(authorize_url) + assert "unauthorized_client" in rv.location + + +def test_invalid_client(test_client): + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": "invalid", + "client_id": "invalid-id", + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + headers = create_basic_header("code-client", "invalid-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": "invalid", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + assert resp["error_uri"] == "https://client.test/error#invalid_client" + + +def test_invalid_code(test_client): + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_request" + + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": "invalid", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_grant" + + code = AuthorizationCode(code="no-user", client_id="code-client", user_id=0) + db.session.add(code) + db.session.commit() + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": "no-user", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_grant" + + +def test_invalid_redirect_uri(test_client): + uri = authorize_url + "&redirect_uri=https%3A%2F%2Fa.c" + rv = test_client.post(uri, data={"user_id": "1"}) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_request" + + uri = authorize_url + "&redirect_uri=https%3A%2F%2Fclient.test" + rv = test_client.post(uri, data={"user_id": "1"}) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + code = params["code"] + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_grant" + + +def test_invalid_grant_type(test_client, client, db): + client.client_secret = "" + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "profile address", + "token_endpoint_auth_method": "none", + "response_types": ["code"], + "grant_types": ["invalid"], + } + ) + db.session.add(client) + db.session.commit() + + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "client_id": "client-id", + "code": "a", + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "unauthorized_client" + + +def test_authorize_token_no_refresh_token(app, test_client, client, db, server): + app.config.update({"OAUTH2_REFRESH_TOKEN_GENERATOR": True}) + server.load_config(app.config) + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "profile address", + "token_endpoint_auth_method": "none", + "response_types": ["code"], + "grant_types": ["authorization_code"], + } + ) + db.session.add(client) + db.session.commit() + + rv = test_client.post(authorize_url, data={"user_id": "1"}) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + code = params["code"] + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + "client_id": "client-id", + }, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert "refresh_token" not in resp + + +def test_authorize_token_has_refresh_token(app, test_client, client, db, server): + app.config.update({"OAUTH2_REFRESH_TOKEN_GENERATOR": True}) + server.load_config(app.config) + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "profile address", + "token_endpoint_auth_method": "client_secret_basic", + "response_types": ["code"], + "grant_types": ["authorization_code", "refresh_token"], + } + ) + db.session.add(client) + db.session.commit() + + url = authorize_url + "&state=bar" + rv = test_client.post(url, data={"user_id": "1"}) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + assert params["state"] == "bar" + + code = params["code"] + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert "refresh_token" in resp + + +def test_invalid_multiple_request_parameters(test_client): + url = ( + authorize_url + + "&scope=profile&state=bar&redirect_uri=https%3A%2F%2Fclient.test&response_type=code" + ) + rv = test_client.get(url) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_request" + assert resp["error_description"] == "Multiple 'response_type' in request." + + +def test_client_secret_post(app, test_client, client, db, server): + app.config.update({"OAUTH2_REFRESH_TOKEN_GENERATOR": True}) + server.load_config(app.config) + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "profile address", + "token_endpoint_auth_method": "client_secret_post", + "response_types": ["code"], + "grant_types": ["authorization_code", "refresh_token"], + } + ) + db.session.add(client) + db.session.commit() + + url = authorize_url + "&state=bar" + rv = test_client.post(url, data={"user_id": "1"}) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + assert params["state"] == "bar" + + code = params["code"] + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "client_id": "client-id", + "client_secret": "client-secret", + "code": code, + }, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert "refresh_token" in resp + + +def test_token_generator(app, test_client, client, server): + m = "tests.flask.test_oauth2.oauth2_server:token_generator" + app.config.update({"OAUTH2_ACCESS_TOKEN_GENERATOR": m}) + server.load_config(app.config) + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "profile address", + "token_endpoint_auth_method": "none", + "response_types": ["code"], + "grant_types": ["authorization_code"], + } + ) + db.session.add(client) + db.session.commit() + + rv = test_client.post(authorize_url, data={"user_id": "1"}) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + code = params["code"] + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + "client_id": "client-id", + }, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert "c-authorization_code.1." in resp["access_token"] + + +def test_missing_scope_uses_default(test_client, client, monkeypatch): + """Per RFC 6749 Section 3.3, when scope is omitted at authorization endpoint, + the server should use a pre-defined default value from client.get_allowed_scope(). + """ + + def get_allowed_scope_with_default(scope): + if scope is None: + return "default_scope" + return scope + + monkeypatch.setattr(client, "get_allowed_scope", get_allowed_scope_with_default) + + rv = test_client.post(authorize_url, data={"user_id": "1"}) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + code = params["code"] + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert resp.get("scope") == "default_scope" + + +def test_missing_scope_empty_default(test_client, client, monkeypatch): + """When client.get_allowed_scope() returns empty string for missing scope, + the authorization should proceed without a scope. + """ + + def get_allowed_scope_empty(scope): + if scope is None: + return "" + return scope + + monkeypatch.setattr(client, "get_allowed_scope", get_allowed_scope_empty) + + rv = test_client.post(authorize_url, data={"user_id": "1"}) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + code = params["code"] + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert resp.get("scope", "") == "" + + +def test_missing_scope_rejected(test_client, client, monkeypatch): + """Per RFC 6749 Section 3.3, when scope is omitted and client.get_allowed_scope() + returns None, the authorization should fail with invalid_scope. + """ + + def get_allowed_scope_reject(scope): + if scope is None: + return None + return scope + + monkeypatch.setattr(client, "get_allowed_scope", get_allowed_scope_reject) + + rv = test_client.post(authorize_url, data={"user_id": "1"}) + assert "error=invalid_scope" in rv.location diff --git a/tests/flask/test_oauth2/test_authorization_code_iss_parameter.py b/tests/flask/test_oauth2/test_authorization_code_iss_parameter.py new file mode 100644 index 000000000..72397405c --- /dev/null +++ b/tests/flask/test_oauth2/test_authorization_code_iss_parameter.py @@ -0,0 +1,85 @@ +import pytest + +from authlib.oauth2.rfc6749.grants import ( + AuthorizationCodeGrant as _AuthorizationCodeGrant, +) +from authlib.oauth2.rfc9207 import IssuerParameter as _IssuerParameter + +from .models import CodeGrantMixin +from .models import save_authorization_code + +authorize_url = "/oauth/authorize?response_type=code&client_id=client-id" + + +class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): + TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"] + + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + +class IssuerParameter(_IssuerParameter): + def get_issuer(self) -> str: + return "https://auth.test" + + +@pytest.fixture(autouse=True) +def server(server): + server.register_grant(AuthorizationCodeGrant) + return server + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "profile address", + "token_endpoint_auth_method": "client_secret_basic", + "response_types": ["code"], + "grant_types": ["authorization_code"], + } + ) + db.session.add(client) + db.session.commit() + return client + + +def test_rfc9207_enabled_success(test_client, server): + """Check that when RFC9207 is implemented, + the authorization response has an ``iss`` parameter.""" + + server.register_extension(IssuerParameter()) + url = authorize_url + "&state=bar" + rv = test_client.post(url, data={"user_id": "1"}) + assert "iss=https%3A%2F%2Fauth.test" in rv.location + + +def test_rfc9207_disabled_success_no_iss(test_client): + """Check that when RFC9207 is not implemented, + the authorization response contains no ``iss`` parameter.""" + + url = authorize_url + "&state=bar" + rv = test_client.post(url, data={"user_id": "1"}) + assert "iss=" not in rv.location + + +def test_rfc9207_enabled_error(test_client, server): + """Check that when RFC9207 is implemented, + the authorization response has an ``iss`` parameter, + even when an error is returned.""" + + server.register_extension(IssuerParameter()) + rv = test_client.post(authorize_url) + assert "error=access_denied" in rv.location + assert "iss=https%3A%2F%2Fauth.test" in rv.location + + +def test_rfc9207_disbled_error_no_iss(test_client): + """Check that when RFC9207 is not implemented, + the authorization response contains no ``iss`` parameter, + even when an error is returned.""" + + rv = test_client.post(authorize_url) + assert "error=access_denied" in rv.location + assert "iss=" not in rv.location diff --git a/tests/flask/test_oauth2/test_client_configuration_endpoint.py b/tests/flask/test_oauth2/test_client_configuration_endpoint.py new file mode 100644 index 000000000..fa61433b5 --- /dev/null +++ b/tests/flask/test_oauth2/test_client_configuration_endpoint.py @@ -0,0 +1,495 @@ +import pytest +from flask import json + +from authlib.oauth2.rfc7592 import ( + ClientConfigurationEndpoint as _ClientConfigurationEndpoint, +) + +from .models import Client +from .models import Token +from .models import db + + +class ClientConfigurationEndpoint(_ClientConfigurationEndpoint): + software_statement_alg_values_supported = ["RS256"] + + def authenticate_token(self, request): + auth_header = request.headers.get("Authorization") + if auth_header: + access_token = auth_header.split()[1] + return Token.query.filter_by(access_token=access_token).first() + + def update_client(self, client, client_metadata, request): + client.set_client_metadata(client_metadata) + db.session.add(client) + db.session.commit() + return client + + def authenticate_client(self, request): + client_id = request.uri.split("/")[-1] + return Client.query.filter_by(client_id=client_id).first() + + def revoke_access_token(self, request, token): + token.revoked = True + db.session.add(token) + db.session.commit() + + def check_permission(self, client, request): + client_id = request.uri.split("/")[-1] + return client_id != "unauthorized_client_id" + + def delete_client(self, client, request): + db.session.delete(client) + db.session.commit() + + def generate_client_registration_info(self, client, request): + return { + "registration_client_uri": request.uri, + "registration_access_token": request.headers["Authorization"].split(" ")[1], + } + + +@pytest.fixture +def metadata(): + return {} + + +@pytest.fixture(autouse=True) +def server(server, app, metadata): + @app.route("/configure_client/", methods=["PUT", "GET", "DELETE"]) + def configure_client(client_id): + return server.create_endpoint_response( + ClientConfigurationEndpoint.ENDPOINT_NAME + ) + + class MyClientConfiguration(ClientConfigurationEndpoint): + def get_server_metadata(test_client): + return metadata + + server.register_endpoint(MyClientConfiguration) + return server + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "client_name": "Authlib", + "scope": "openid profile", + } + ) + db.session.add(client) + db.session.commit() + return client + + +def test_read_client(test_client, client, token): + assert client.client_name == "Authlib" + headers = {"Authorization": f"bearer {token.access_token}"} + rv = test_client.get("/configure_client/client-id", headers=headers) + resp = json.loads(rv.data) + assert rv.status_code == 200 + assert resp["client_id"] == client.client_id + assert resp["client_name"] == "Authlib" + assert ( + resp["registration_client_uri"] == "http://localhost/configure_client/client-id" + ) + assert resp["registration_access_token"] == token.access_token + + +def test_read_access_denied(test_client): + rv = test_client.get("/configure_client/client-id") + resp = json.loads(rv.data) + assert rv.status_code == 400 + assert resp["error"] == "access_denied" + + headers = {"Authorization": "bearer invalid_token"} + rv = test_client.get("/configure_client/client-id", headers=headers) + resp = json.loads(rv.data) + assert rv.status_code == 400 + assert resp["error"] == "access_denied" + + headers = {"Authorization": "bearer unauthorized_token"} + rv = test_client.get( + "/configure_client/client-id", + json={"client_id": "client-id", "client_name": "new client_name"}, + headers=headers, + ) + resp = json.loads(rv.data) + assert rv.status_code == 400 + assert resp["error"] == "access_denied" + + +def test_read_invalid_client(test_client, token): + # If the client does not exist on this server, the server MUST respond + # with HTTP 401 Unauthorized, and the registration access token used to + # make this request SHOULD be immediately revoked. + + headers = {"Authorization": f"bearer {token.access_token}"} + rv = test_client.get("/configure_client/invalid_client_id", headers=headers) + resp = json.loads(rv.data) + assert rv.status_code == 401 + assert resp["error"] == "invalid_client" + + +def test_read_unauthorized_client(test_client, token): + # If the client does not have permission to read its record, the server + # MUST return an HTTP 403 Forbidden. + + client = Client( + client_id="unauthorized_client_id", + client_secret="unauthorized_client_secret", + ) + db.session.add(client) + + headers = {"Authorization": f"bearer {token.access_token}"} + rv = test_client.get("/configure_client/unauthorized_client_id", headers=headers) + resp = json.loads(rv.data) + assert rv.status_code == 403 + assert resp["error"] == "unauthorized_client" + + +def test_update_client(test_client, client, token): + # Valid values of client metadata fields in this request MUST replace, + # not augment, the values previously associated with this client. + # Omitted fields MUST be treated as null or empty values by the server, + # indicating the client's request to delete them from the client's + # registration. The authorization server MAY ignore any null or empty + # value in the request just as any other value. + + assert client.client_name == "Authlib" + headers = {"Authorization": f"bearer {token.access_token}"} + body = { + "client_id": client.client_id, + "client_name": "NewAuthlib", + } + rv = test_client.put("/configure_client/client-id", json=body, headers=headers) + resp = json.loads(rv.data) + assert rv.status_code == 200 + assert resp["client_id"] == client.client_id + assert resp["client_name"] == "NewAuthlib" + assert client.client_name == "NewAuthlib" + assert client.scope == "" + + +def test_update_access_denied(test_client): + rv = test_client.put("/configure_client/client-id", json={}) + resp = json.loads(rv.data) + assert rv.status_code == 400 + assert resp["error"] == "access_denied" + + headers = {"Authorization": "bearer invalid_token"} + rv = test_client.put("/configure_client/client-id", json={}, headers=headers) + resp = json.loads(rv.data) + assert rv.status_code == 400 + assert resp["error"] == "access_denied" + + headers = {"Authorization": "bearer unauthorized_token"} + rv = test_client.put( + "/configure_client/client-id", + json={"client_id": "client-id", "client_name": "new client_name"}, + headers=headers, + ) + resp = json.loads(rv.data) + assert rv.status_code == 400 + assert resp["error"] == "access_denied" + + +def test_update_invalid_request(test_client, token): + headers = {"Authorization": f"bearer {token.access_token}"} + + # The client MUST include its 'client_id' field in the request... + rv = test_client.put("/configure_client/client-id", json={}, headers=headers) + resp = json.loads(rv.data) + assert rv.status_code == 400 + assert resp["error"] == "invalid_request" + + # ... and it MUST be the same as its currently issued client identifier. + rv = test_client.put( + "/configure_client/client-id", + json={"client_id": "invalid_client_id"}, + headers=headers, + ) + resp = json.loads(rv.data) + assert rv.status_code == 400 + assert resp["error"] == "invalid_request" + + # The updated client metadata fields request MUST NOT include the + # 'registration_access_token', 'registration_client_uri', + # 'client_secret_expires_at', or 'client_id_issued_at' fields + rv = test_client.put( + "/configure_client/client-id", + json={ + "client_id": "client-id", + "registration_client_uri": "https://client.test", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert rv.status_code == 400 + assert resp["error"] == "invalid_request" + + # If the client includes the 'client_secret' field in the request, + # the value of this field MUST match the currently issued client + # secret for that client. + rv = test_client.put( + "/configure_client/client-id", + json={"client_id": "client-id", "client_secret": "invalid_secret"}, + headers=headers, + ) + resp = json.loads(rv.data) + assert rv.status_code == 400 + assert resp["error"] == "invalid_request" + + +def test_update_invalid_client(test_client, token): + # If the client does not exist on this server, the server MUST respond + # with HTTP 401 Unauthorized, and the registration access token used to + # make this request SHOULD be immediately revoked. + + headers = {"Authorization": f"bearer {token.access_token}"} + rv = test_client.put( + "/configure_client/invalid_client_id", + json={"client_id": "invalid_client_id", "client_name": "new client_name"}, + headers=headers, + ) + resp = json.loads(rv.data) + assert rv.status_code == 401 + assert resp["error"] == "invalid_client" + + +def test_update_unauthorized_client(test_client, token): + # If the client does not have permission to read its record, the server + # MUST return an HTTP 403 Forbidden. + + client = Client( + client_id="unauthorized_client_id", + client_secret="unauthorized_client_secret", + ) + db.session.add(client) + + headers = {"Authorization": f"bearer {token.access_token}"} + rv = test_client.put( + "/configure_client/unauthorized_client_id", + json={ + "client_id": "unauthorized_client_id", + "client_name": "new client_name", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert rv.status_code == 403 + assert resp["error"] == "unauthorized_client" + + +def test_update_invalid_metadata(test_client, metadata, client, token): + metadata["token_endpoint_auth_methods_supported"] = ["client_secret_basic"] + headers = {"Authorization": f"bearer {token.access_token}"} + + # For all metadata fields, the authorization server MAY replace any + # invalid values with suitable default values, and it MUST return any + # such fields to the client in the response. + # If the client attempts to set an invalid metadata field and the + # authorization server does not set a default value, the authorization + # server responds with an error as described in [RFC7591]. + + body = { + "client_id": client.client_id, + "client_name": "NewAuthlib", + "token_endpoint_auth_method": "invalid_auth_method", + } + rv = test_client.put("/configure_client/client-id", json=body, headers=headers) + resp = json.loads(rv.data) + assert rv.status_code == 400 + assert resp["error"] == "invalid_client_metadata" + + +def test_update_scopes_supported(test_client, metadata, token): + metadata["scopes_supported"] = ["profile", "email"] + + headers = {"Authorization": f"bearer {token.access_token}"} + body = { + "client_id": "client-id", + "scope": "profile email", + "client_name": "Authlib", + } + rv = test_client.put("/configure_client/client-id", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["client_id"] == "client-id" + assert resp["client_name"] == "Authlib" + assert resp["scope"] == "profile email" + + headers = {"Authorization": f"bearer {token.access_token}"} + body = { + "client_id": "client-id", + "scope": "", + "client_name": "Authlib", + } + rv = test_client.put("/configure_client/client-id", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["client_id"] == "client-id" + assert resp["client_name"] == "Authlib" + + body = { + "client_id": "client-id", + "scope": "profile email address", + "client_name": "Authlib", + } + rv = test_client.put("/configure_client/client-id", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_update_response_types_supported(test_client, metadata, token): + metadata["response_types_supported"] = ["code"] + + headers = {"Authorization": f"bearer {token.access_token}"} + body = { + "client_id": "client-id", + "response_types": ["code"], + "client_name": "Authlib", + } + rv = test_client.put("/configure_client/client-id", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["client_id"] == "client-id" + assert resp["client_name"] == "Authlib" + assert resp["response_types"] == ["code"] + + # https://datatracker.ietf.org/doc/html/rfc7592#section-2.2 + # If omitted, the default is that the client will use only the "code" + # response type. + body = {"client_id": "client-id", "client_name": "Authlib"} + rv = test_client.put("/configure_client/client-id", json=body, headers=headers) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert "response_types" not in resp + + body = { + "client_id": "client-id", + "response_types": ["code", "token"], + "client_name": "Authlib", + } + rv = test_client.put("/configure_client/client-id", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_update_grant_types_supported(test_client, metadata, token): + metadata["grant_types_supported"] = ["authorization_code", "password"] + + headers = {"Authorization": f"bearer {token.access_token}"} + body = { + "client_id": "client-id", + "grant_types": ["password"], + "client_name": "Authlib", + } + rv = test_client.put("/configure_client/client-id", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["client_id"] == "client-id" + assert resp["client_name"] == "Authlib" + assert resp["grant_types"] == ["password"] + + # https://datatracker.ietf.org/doc/html/rfc7592#section-2.2 + # If omitted, the default behavior is that the client will use only + # the "authorization_code" Grant Type. + body = {"client_id": "client-id", "client_name": "Authlib"} + rv = test_client.put("/configure_client/client-id", json=body, headers=headers) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert "grant_types" not in resp + + body = { + "client_id": "client-id", + "grant_types": ["client_credentials"], + "client_name": "Authlib", + } + rv = test_client.put("/configure_client/client-id", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_update_token_endpoint_auth_methods_supported(test_client, metadata, token): + metadata["token_endpoint_auth_methods_supported"] = ["client_secret_basic"] + + headers = {"Authorization": f"bearer {token.access_token}"} + body = { + "client_id": "client-id", + "token_endpoint_auth_method": "client_secret_basic", + "client_name": "Authlib", + } + rv = test_client.put("/configure_client/client-id", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["client_id"] == "client-id" + assert resp["client_name"] == "Authlib" + assert resp["token_endpoint_auth_method"] == "client_secret_basic" + + body = { + "client_id": "client-id", + "token_endpoint_auth_method": "none", + "client_name": "Authlib", + } + rv = test_client.put("/configure_client/client-id", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_delete_client(test_client, client, token): + assert client.client_name == "Authlib" + headers = {"Authorization": f"bearer {token.access_token}"} + rv = test_client.delete("/configure_client/client-id", headers=headers) + assert rv.status_code == 204 + assert not rv.data + + +def test_delete_access_denied(test_client): + rv = test_client.delete("/configure_client/client-id") + resp = json.loads(rv.data) + assert rv.status_code == 400 + assert resp["error"] == "access_denied" + + headers = {"Authorization": "bearer invalid_token"} + rv = test_client.delete("/configure_client/client-id", headers=headers) + resp = json.loads(rv.data) + assert rv.status_code == 400 + assert resp["error"] == "access_denied" + + headers = {"Authorization": "bearer unauthorized_token"} + rv = test_client.delete( + "/configure_client/client-id", + json={"client_id": "client-id", "client_name": "new client_name"}, + headers=headers, + ) + resp = json.loads(rv.data) + assert rv.status_code == 400 + assert resp["error"] == "access_denied" + + +def test_delete_invalid_client(test_client, token): + # If the client does not exist on this server, the server MUST respond + # with HTTP 401 Unauthorized, and the registration access token used to + # make this request SHOULD be immediately revoked. + + headers = {"Authorization": f"bearer {token.access_token}"} + rv = test_client.delete("/configure_client/invalid_client_id", headers=headers) + resp = json.loads(rv.data) + assert rv.status_code == 401 + assert resp["error"] == "invalid_client" + + +def test_delete_unauthorized_client(test_client, token): + # If the client does not have permission to read its record, the server + # MUST return an HTTP 403 Forbidden. + + client = Client( + client_id="unauthorized_client_id", + client_secret="unauthorized_client_secret", + ) + db.session.add(client) + + headers = {"Authorization": f"bearer {token.access_token}"} + rv = test_client.delete("/configure_client/unauthorized_client_id", headers=headers) + resp = json.loads(rv.data) + assert rv.status_code == 403 + assert resp["error"] == "unauthorized_client" diff --git a/tests/flask/test_oauth2/test_client_credentials_grant.py b/tests/flask/test_oauth2/test_client_credentials_grant.py index ec7c9a0a6..009e48c20 100644 --- a/tests/flask/test_oauth2/test_client_credentials_grant.py +++ b/tests/flask/test_oauth2/test_client_credentials_grant.py @@ -1,95 +1,184 @@ +import pytest from flask import json + from authlib.oauth2.rfc6749.grants import ClientCredentialsGrant -from .models import db, User, Client -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server - - -class ClientCredentialsTest(TestCase): - def prepare_data(self, grant_type='client_credentials'): - server = create_authorization_server(self.app) - server.register_grant(ClientCredentialsGrant) - self.server = server - - user = User(username='foo') - db.session.add(user) - db.session.commit() - client = Client( - user_id=user.id, - client_id='credential-client', - client_secret='credential-secret', - ) - client.set_client_metadata({ - 'scope': 'profile', - 'redirect_uris': ['http://localhost/authorized'], - 'grant_types': [grant_type] - }) - db.session.add(client) - db.session.commit() - - def test_invalid_client(self): - self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - headers = self.create_basic_header( - 'credential-client', 'invalid-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - def test_invalid_grant_type(self): - self.prepare_data(grant_type='invalid') - headers = self.create_basic_header( - 'credential-client', 'credential-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'unauthorized_client') - - def test_invalid_scope(self): - self.prepare_data() - self.server.metadata = {'scopes_supported': ['profile']} - headers = self.create_basic_header( - 'credential-client', 'credential-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - 'scope': 'invalid', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_scope') - - def test_authorize_token(self): - self.prepare_data() - headers = self.create_basic_header( - 'credential-client', 'credential-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - - def test_token_generator(self): - m = 'tests.flask.test_oauth2.oauth2_server:token_generator' - self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) - - self.prepare_data() - headers = self.create_basic_header( - 'credential-client', 'credential-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - self.assertIn('c-client_credentials.', resp['access_token']) + +from .oauth2_server import create_basic_header + + +@pytest.fixture(autouse=True) +def server(server): + server.register_grant(ClientCredentialsGrant) + return server + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["https://client.test/authorized"], + "grant_types": ["client_credentials"], + } + ) + db.session.add(client) + db.session.commit() + return client + + +def test_invalid_client(test_client): + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + headers = create_basic_header("client-id", "invalid-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + +def test_invalid_grant_type(test_client, client, db): + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["https://client.test/authorized"], + "grant_types": ["invalid"], + } + ) + db.session.add(client) + db.session.commit() + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "unauthorized_client" + + +def test_invalid_scope(test_client, server): + server.scopes_supported = ["profile"] + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + "scope": "invalid", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_scope" + + +def test_authorize_token(test_client): + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + + +def test_token_generator(test_client, app, server): + m = "tests.flask.test_oauth2.oauth2_server:token_generator" + app.config.update({"OAUTH2_ACCESS_TOKEN_GENERATOR": m}) + server.load_config(app.config) + + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert "c-client_credentials." in resp["access_token"] + + +def test_missing_scope_uses_default(test_client, client, monkeypatch): + """Per RFC 6749 Section 3.3, when scope is omitted, the server should use + a pre-defined default value from client.get_allowed_scope(). + """ + + def get_allowed_scope_with_default(scope): + if scope is None: + return "default_scope" + return scope + + monkeypatch.setattr(client, "get_allowed_scope", get_allowed_scope_with_default) + + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={"grant_type": "client_credentials"}, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert resp.get("scope") == "default_scope" + + +def test_missing_scope_empty_default(test_client, client, monkeypatch): + """When client.get_allowed_scope() returns empty string for missing scope, + the token should be issued without a scope. + """ + + def get_allowed_scope_empty(scope): + if scope is None: + return "" + return scope + + monkeypatch.setattr(client, "get_allowed_scope", get_allowed_scope_empty) + + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={"grant_type": "client_credentials"}, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert resp.get("scope", "") == "" + + +def test_missing_scope_rejected(test_client, client, monkeypatch): + """Per RFC 6749 Section 3.3, when scope is omitted and client.get_allowed_scope() + returns None, the server should fail with invalid_scope. + """ + + def get_allowed_scope_reject(scope): + if scope is None: + return None + return scope + + monkeypatch.setattr(client, "get_allowed_scope", get_allowed_scope_reject) + + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={"grant_type": "client_credentials"}, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_scope" diff --git a/tests/flask/test_oauth2/test_client_registration_endpoint.py b/tests/flask/test_oauth2/test_client_registration_endpoint.py deleted file mode 100644 index 3c987cf13..000000000 --- a/tests/flask/test_oauth2/test_client_registration_endpoint.py +++ /dev/null @@ -1,170 +0,0 @@ -from flask import json -from authlib.jose import jwt -from authlib.oauth2.rfc7591 import ClientRegistrationEndpoint as _ClientRegistrationEndpoint -from tests.util import read_file_path -from .models import db, User, Client -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server - - -class ClientRegistrationEndpoint(_ClientRegistrationEndpoint): - software_statement_alg_values_supported = ['RS256'] - - def authenticate_token(self, request): - auth_header = request.headers.get('Authorization') - if auth_header: - request.user_id = 1 - return auth_header - - def resolve_public_key(self, request): - return read_file_path('rsa_public.pem') - - def save_client(self, client_info, client_metadata, request): - client = Client( - user_id=request.user_id, - **client_info - ) - client.set_client_metadata(client_metadata) - db.session.add(client) - db.session.commit() - return client - - -class ClientRegistrationTest(TestCase): - def prepare_data(self, endpoint_cls=None, metadata=None): - app = self.app - server = create_authorization_server(app) - if metadata: - server.metadata = metadata - - if endpoint_cls is None: - endpoint_cls = ClientRegistrationEndpoint - server.register_endpoint(endpoint_cls) - - @app.route('/create_client', methods=['POST']) - def create_client(): - return server.create_endpoint_response('client_registration') - - user = User(username='foo') - db.session.add(user) - db.session.commit() - - def test_access_denied(self): - self.prepare_data() - rv = self.client.post('/create_client') - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'access_denied') - - def test_invalid_request(self): - self.prepare_data() - headers = {'Authorization': 'bearer abc'} - rv = self.client.post('/create_client', headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - def test_create_client(self): - self.prepare_data() - headers = {'Authorization': 'bearer abc'} - body = { - 'client_name': 'Authlib' - } - rv = self.client.post('/create_client', json=body, headers=headers) - resp = json.loads(rv.data) - self.assertIn('client_id', resp) - self.assertEqual(resp['client_name'], 'Authlib') - - def test_software_statement(self): - payload = {'software_id': 'uuid-123', 'client_name': 'Authlib'} - s = jwt.encode({'alg': 'RS256'}, payload, read_file_path('rsa_private.pem')) - body = { - 'software_statement': s.decode('utf-8'), - } - - self.prepare_data() - headers = {'Authorization': 'bearer abc'} - rv = self.client.post('/create_client', json=body, headers=headers) - resp = json.loads(rv.data) - self.assertIn('client_id', resp) - self.assertEqual(resp['client_name'], 'Authlib') - - def test_no_public_key(self): - - class ClientRegistrationEndpoint2(ClientRegistrationEndpoint): - def resolve_public_key(self, request): - return None - - payload = {'software_id': 'uuid-123', 'client_name': 'Authlib'} - s = jwt.encode({'alg': 'RS256'}, payload, read_file_path('rsa_private.pem')) - body = { - 'software_statement': s.decode('utf-8'), - } - - self.prepare_data(ClientRegistrationEndpoint2) - headers = {'Authorization': 'bearer abc'} - rv = self.client.post('/create_client', json=body, headers=headers) - resp = json.loads(rv.data) - self.assertIn(resp['error'], 'unapproved_software_statement') - - def test_scopes_supported(self): - metadata = {'scopes_supported': ['profile', 'email']} - self.prepare_data(metadata=metadata) - - headers = {'Authorization': 'bearer abc'} - body = {'scope': 'profile email', 'client_name': 'Authlib'} - rv = self.client.post('/create_client', json=body, headers=headers) - resp = json.loads(rv.data) - self.assertIn('client_id', resp) - self.assertEqual(resp['client_name'], 'Authlib') - - body = {'scope': 'profile email address', 'client_name': 'Authlib'} - rv = self.client.post('/create_client', json=body, headers=headers) - resp = json.loads(rv.data) - self.assertIn(resp['error'], 'invalid_client_metadata') - - def test_response_types_supported(self): - metadata = {'response_types_supported': ['code']} - self.prepare_data(metadata=metadata) - - headers = {'Authorization': 'bearer abc'} - body = {'response_types': ['code'], 'client_name': 'Authlib'} - rv = self.client.post('/create_client', json=body, headers=headers) - resp = json.loads(rv.data) - self.assertIn('client_id', resp) - self.assertEqual(resp['client_name'], 'Authlib') - - body = {'response_types': ['code', 'token'], 'client_name': 'Authlib'} - rv = self.client.post('/create_client', json=body, headers=headers) - resp = json.loads(rv.data) - self.assertIn(resp['error'], 'invalid_client_metadata') - - def test_grant_types_supported(self): - metadata = {'grant_types_supported': ['authorization_code', 'password']} - self.prepare_data(metadata=metadata) - - headers = {'Authorization': 'bearer abc'} - body = {'grant_types': ['password'], 'client_name': 'Authlib'} - rv = self.client.post('/create_client', json=body, headers=headers) - resp = json.loads(rv.data) - self.assertIn('client_id', resp) - self.assertEqual(resp['client_name'], 'Authlib') - - body = {'grant_types': ['client_credentials'], 'client_name': 'Authlib'} - rv = self.client.post('/create_client', json=body, headers=headers) - resp = json.loads(rv.data) - self.assertIn(resp['error'], 'invalid_client_metadata') - - def test_token_endpoint_auth_methods_supported(self): - metadata = {'token_endpoint_auth_methods_supported': ['client_secret_basic']} - self.prepare_data(metadata=metadata) - - headers = {'Authorization': 'bearer abc'} - body = {'token_endpoint_auth_method': 'client_secret_basic', 'client_name': 'Authlib'} - rv = self.client.post('/create_client', json=body, headers=headers) - resp = json.loads(rv.data) - self.assertIn('client_id', resp) - self.assertEqual(resp['client_name'], 'Authlib') - - body = {'token_endpoint_auth_method': 'none', 'client_name': 'Authlib'} - rv = self.client.post('/create_client', json=body, headers=headers) - resp = json.loads(rv.data) - self.assertIn(resp['error'], 'invalid_client_metadata') diff --git a/tests/flask/test_oauth2/test_client_registration_endpoint_oauth2.py b/tests/flask/test_oauth2/test_client_registration_endpoint_oauth2.py new file mode 100644 index 000000000..6859e40fa --- /dev/null +++ b/tests/flask/test_oauth2/test_client_registration_endpoint_oauth2.py @@ -0,0 +1,253 @@ +import pytest +from flask import json +from joserfc import jwt +from joserfc.jwk import KeySet +from joserfc.jwk import RSAKey + +from authlib.oauth2.rfc7591 import ( + ClientRegistrationEndpoint as _ClientRegistrationEndpoint, +) +from tests.util import read_file_path + +from .models import Client +from .models import db + + +class ClientRegistrationEndpoint(_ClientRegistrationEndpoint): + software_statement_alg_values_supported = ["RS256"] + + def authenticate_token(self, request): + auth_header = request.headers.get("Authorization") + if auth_header: + request.user_id = 1 + return auth_header + + def resolve_public_key(self, request): + return read_file_path("rsa_public.pem") + + def save_client(self, client_info, client_metadata, request): + client = Client(user_id=request.user_id, **client_info) + client.set_client_metadata(client_metadata) + db.session.add(client) + db.session.commit() + return client + + +@pytest.fixture +def metadata(): + return {} + + +@pytest.fixture(autouse=True) +def server(server, app, metadata): + class MyClientRegistration(ClientRegistrationEndpoint): + def get_server_metadata(test_client): + return metadata + + server.register_endpoint(MyClientRegistration) + + @app.route("/create_client", methods=["POST"]) + def create_client(): + return server.create_endpoint_response("client_registration") + + return server + + +def test_access_denied(test_client): + rv = test_client.post("/create_client", json={}) + resp = json.loads(rv.data) + assert resp["error"] == "access_denied" + + +def test_invalid_request(test_client): + headers = {"Authorization": "bearer abc"} + rv = test_client.post("/create_client", json={}, headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_request" + + +def test_create_client(test_client): + headers = {"Authorization": "bearer abc"} + body = {"client_name": "Authlib"} + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + + +def test_software_statement(test_client): + payload = {"software_id": "uuid-123", "client_name": "Authlib"} + key = RSAKey.import_key(read_file_path("rsa_private.pem")) + software_statement = jwt.encode({"alg": "RS256"}, payload, key) + body = { + "software_statement": software_statement, + } + + headers = {"Authorization": "bearer abc"} + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + + +def test_no_public_key(test_client, server): + class ClientRegistrationEndpoint2(ClientRegistrationEndpoint): + def get_server_metadata(test_client): + return None + + def resolve_public_key(self, request): + return None + + payload = {"software_id": "uuid-123", "client_name": "Authlib"} + key = RSAKey.import_key(read_file_path("rsa_private.pem")) + software_statement = jwt.encode({"alg": "RS256"}, payload, key) + body = { + "software_statement": software_statement, + } + + server._endpoints[ClientRegistrationEndpoint.ENDPOINT_NAME] = [ + ClientRegistrationEndpoint2(server) + ] + + headers = {"Authorization": "bearer abc"} + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["error"] in "unapproved_software_statement" + + +def test_scopes_supported(test_client, metadata): + metadata["scopes_supported"] = ["profile", "email"] + + headers = {"Authorization": "bearer abc"} + body = {"scope": "profile email", "client_name": "Authlib"} + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + + body = {"scope": "profile email address", "client_name": "Authlib"} + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_response_types_supported(test_client, metadata): + metadata["response_types_supported"] = ["code", "code id_token"] + + headers = {"Authorization": "bearer abc"} + body = {"response_types": ["code"], "client_name": "Authlib"} + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + + # The items order should not matter + # Extension response types MAY contain a space-delimited (%x20) list of + # values, where the order of values does not matter (e.g., response + # type "a b" is the same as "b a"). + headers = {"Authorization": "bearer abc"} + body = {"response_types": ["id_token code"], "client_name": "Authlib"} + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + + # https://www.rfc-editor.org/rfc/rfc7591.html#section-2 + # If omitted, the default is that the client will use only the "code" + # response type. + body = {"client_name": "Authlib"} + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + + body = {"response_types": ["code", "token"], "client_name": "Authlib"} + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_grant_types_supported(test_client, metadata): + metadata["grant_types_supported"] = ["authorization_code", "password"] + + headers = {"Authorization": "bearer abc"} + body = {"grant_types": ["password"], "client_name": "Authlib"} + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + + # https://www.rfc-editor.org/rfc/rfc7591.html#section-2 + # If omitted, the default behavior is that the client will use only + # the "authorization_code" Grant Type. + body = {"client_name": "Authlib"} + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + + body = {"grant_types": ["client_credentials"], "client_name": "Authlib"} + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_token_endpoint_auth_methods_supported(test_client, metadata): + metadata["token_endpoint_auth_methods_supported"] = ["client_secret_basic"] + + headers = {"Authorization": "bearer abc"} + body = { + "token_endpoint_auth_method": "client_secret_basic", + "client_name": "Authlib", + } + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + + body = {"token_endpoint_auth_method": "none", "client_name": "Authlib"} + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_validate_contacts(test_client): + headers = {"Authorization": "bearer abc"} + body = {"client_name": "Authlib", "contacts": "invalid"} + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert "contacts" in resp["error_description"] + + +def test_validate_jwks(test_client): + headers = {"Authorization": "bearer abc"} + + keyset = KeySet.generate_key_set("oct", 128, count=1) + valid_jwks = keyset.as_dict() + + body = {"client_name": "Authlib", "jwks": valid_jwks} + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["client_name"] == "Authlib" + + # case 1: jwks and jwks_uri both provided + body = { + "client_name": "Authlib", + "jwks": valid_jwks, + "jwks_uri": "http://testserver/jwks", + } + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert "jwks" in resp["error_description"] + + # case 2: empty jwks + body = {"client_name": "Authlib", "jwks": {"keys": []}} + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert "jwks" in resp["error_description"] + + # case 3: invalid jwks + body = {"client_name": "Authlib", "jwks": {"keys": "hello"}} + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert "jwks" in resp["error_description"] diff --git a/tests/flask/test_oauth2/test_client_registration_endpoint_oidc.py b/tests/flask/test_oauth2/test_client_registration_endpoint_oidc.py new file mode 100644 index 000000000..e361d4d3e --- /dev/null +++ b/tests/flask/test_oauth2/test_client_registration_endpoint_oidc.py @@ -0,0 +1,622 @@ +import pytest +from flask import json + +from authlib.oauth2.rfc7591 import ClientMetadataClaims as OAuth2ClientMetadataClaims +from authlib.oauth2.rfc7591 import ( + ClientRegistrationEndpoint as _ClientRegistrationEndpoint, +) +from authlib.oidc.registration import ClientMetadataClaims as OIDCClientMetadataClaims +from tests.util import read_file_path + +from .models import Client +from .models import db + + +class ClientRegistrationEndpoint(_ClientRegistrationEndpoint): + software_statement_alg_values_supported = ["RS256"] + + def authenticate_token(self, request): + auth_header = request.headers.get("Authorization") + if auth_header: + request.user_id = 1 + return auth_header + + def resolve_public_key(self, request): + return read_file_path("rsa_public.pem") + + def save_client(self, client_info, client_metadata, request): + client = Client(user_id=request.user_id, **client_info) + client.set_client_metadata(client_metadata) + db.session.add(client) + db.session.commit() + return client + + +@pytest.fixture +def metadata(): + return {} + + +@pytest.fixture(autouse=True) +def server(server, app, metadata): + class MyClientRegistration(ClientRegistrationEndpoint): + def get_server_metadata(self): + return metadata + + server.register_endpoint( + MyClientRegistration( + claims_classes=[OAuth2ClientMetadataClaims, OIDCClientMetadataClaims] + ) + ) + + @app.route("/create_client", methods=["POST"]) + def create_client(): + return server.create_endpoint_response("client_registration") + + return server + + +def test_application_type(test_client): + # Nominal case + body = { + "application_type": "web", + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["application_type"] == "web" + + # Default case + # The default, if omitted, is that any algorithm supported by the OP and the RP MAY be used. + body = { + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["application_type"] == "web" + + # Error case + body = { + "application_type": "invalid", + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_token_endpoint_auth_signing_alg_supported(test_client, metadata): + metadata["token_endpoint_auth_signing_alg_values_supported"] = ["RS256", "ES256"] + + # Nominal case + body = { + "token_endpoint_auth_signing_alg": "ES256", + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["token_endpoint_auth_signing_alg"] == "ES256" + + # Default case + # The default, if omitted, is that any algorithm supported by the OP and the RP MAY be used. + body = { + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + + # Error case + body = { + "token_endpoint_auth_signing_alg": "RS512", + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_subject_types_supported(test_client, metadata): + metadata["subject_types_supported"] = ["public", "pairwise"] + + # Nominal case + body = {"subject_type": "public", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["subject_type"] == "public" + + # Error case + body = {"subject_type": "invalid", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_id_token_signing_alg_values_supported(test_client, metadata): + metadata["id_token_signing_alg_values_supported"] = ["RS256", "ES256"] + + # Default + # The default, if omitted, is RS256. + body = {"client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["id_token_signed_response_alg"] == "RS256" + + # Nominal case + body = {"id_token_signed_response_alg": "ES256", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["id_token_signed_response_alg"] == "ES256" + + # Error case + body = {"id_token_signed_response_alg": "RS512", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client_metadata" + + +def test_id_token_signing_alg_values_none(test_client, metadata): + # The value none MUST NOT be used as the ID Token alg value unless the Client uses + # only Response Types that return no ID Token from the Authorization Endpoint + # (such as when only using the Authorization Code Flow). + metadata["id_token_signing_alg_values_supported"] = ["none", "RS256", "ES256"] + + # Nominal case + body = { + "id_token_signed_response_alg": "none", + "client_name": "Authlib", + "response_type": "code", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["id_token_signed_response_alg"] == "none" + + # Error case + body = { + "id_token_signed_response_alg": "none", + "client_name": "Authlib", + "response_type": "id_token", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client_metadata" + + +def test_id_token_encryption_alg_values_supported(test_client, metadata): + metadata["id_token_encryption_alg_values_supported"] = ["RS256", "ES256"] + + # Default case + body = {"client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert "id_token_encrypted_response_enc" not in resp + + # If id_token_encrypted_response_alg is specified, the default + # id_token_encrypted_response_enc value is A128CBC-HS256. + body = {"id_token_encrypted_response_alg": "RS256", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["id_token_encrypted_response_enc"] == "A128CBC-HS256" + + # Nominal case + body = {"id_token_encrypted_response_alg": "ES256", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["id_token_encrypted_response_alg"] == "ES256" + + # Error case + body = {"id_token_encrypted_response_alg": "RS512", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_id_token_encryption_enc_values_supported(test_client, metadata): + metadata["id_token_encryption_enc_values_supported"] = ["A128CBC-HS256", "A256GCM"] + + # Nominal case + body = { + "id_token_encrypted_response_alg": "RS256", + "id_token_encrypted_response_enc": "A256GCM", + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["id_token_encrypted_response_alg"] == "RS256" + assert resp["id_token_encrypted_response_enc"] == "A256GCM" + + # Error case: missing id_token_encrypted_response_alg + body = {"id_token_encrypted_response_enc": "A256GCM", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + # Error case: alg not in server metadata + body = {"id_token_encrypted_response_enc": "A128GCM", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_userinfo_signing_alg_values_supported(test_client, metadata): + metadata["userinfo_signing_alg_values_supported"] = ["RS256", "ES256"] + + # Nominal case + body = {"userinfo_signed_response_alg": "ES256", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["userinfo_signed_response_alg"] == "ES256" + + # Error case + body = {"userinfo_signed_response_alg": "RS512", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_userinfo_encryption_alg_values_supported(test_client, metadata): + metadata["userinfo_encryption_alg_values_supported"] = ["RS256", "ES256"] + + # Nominal case + body = {"userinfo_encrypted_response_alg": "ES256", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["userinfo_encrypted_response_alg"] == "ES256" + + # Error case + body = {"userinfo_encrypted_response_alg": "RS512", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_userinfo_encryption_enc_values_supported(test_client, metadata): + metadata["userinfo_encryption_enc_values_supported"] = ["A128CBC-HS256", "A256GCM"] + + # Default case + body = {"client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert "userinfo_encrypted_response_enc" not in resp + + # If userinfo_encrypted_response_alg is specified, the default + # userinfo_encrypted_response_enc value is A128CBC-HS256. + body = {"userinfo_encrypted_response_alg": "RS256", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["userinfo_encrypted_response_enc"] == "A128CBC-HS256" + + # Nominal case + body = { + "userinfo_encrypted_response_alg": "RS256", + "userinfo_encrypted_response_enc": "A256GCM", + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["userinfo_encrypted_response_alg"] == "RS256" + assert resp["userinfo_encrypted_response_enc"] == "A256GCM" + + # Error case: no userinfo_encrypted_response_alg + body = {"userinfo_encrypted_response_enc": "A256GCM", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + # Error case: alg not in server metadata + body = {"userinfo_encrypted_response_enc": "A128GCM", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_acr_values_supported(test_client, metadata): + metadata["acr_values_supported"] = [ + "urn:mace:incommon:iap:silver", + "urn:mace:incommon:iap:bronze", + ] + + # Nominal case + body = { + "default_acr_values": ["urn:mace:incommon:iap:silver"], + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["default_acr_values"] == ["urn:mace:incommon:iap:silver"] + + # Error case + body = { + "default_acr_values": [ + "urn:mace:incommon:iap:silver", + "urn:mace:incommon:iap:gold", + ], + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_request_object_signing_alg_values_supported(test_client, metadata): + metadata["request_object_signing_alg_values_supported"] = ["RS256", "ES256"] + + # Nominal case + body = {"request_object_signing_alg": "ES256", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["request_object_signing_alg"] == "ES256" + + # Error case + body = {"request_object_signing_alg": "RS512", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_request_object_encryption_alg_values_supported(test_client, metadata): + metadata["request_object_encryption_alg_values_supported"] = ["RS256", "ES256"] + + # Nominal case + body = { + "request_object_encryption_alg": "ES256", + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["request_object_encryption_alg"] == "ES256" + + # Error case + body = { + "request_object_encryption_alg": "RS512", + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_request_object_encryption_enc_values_supported(test_client, metadata): + metadata["request_object_encryption_enc_values_supported"] = [ + "A128CBC-HS256", + "A256GCM", + ] + + # Default case + body = {"client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert "request_object_encryption_enc" not in resp + + # If request_object_encryption_alg is specified, the default + # request_object_encryption_enc value is A128CBC-HS256. + body = {"request_object_encryption_alg": "RS256", "client_name": "Authlib"} + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["request_object_encryption_enc"] == "A128CBC-HS256" + + # Nominal case + body = { + "request_object_encryption_alg": "RS256", + "request_object_encryption_enc": "A256GCM", + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["request_object_encryption_alg"] == "RS256" + assert resp["request_object_encryption_enc"] == "A256GCM" + + # Error case: missing request_object_encryption_alg + body = { + "request_object_encryption_enc": "A256GCM", + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + # Error case: alg not in server metadata + body = { + "request_object_encryption_enc": "A128GCM", + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_require_auth_time(test_client): + # Default case + body = { + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["require_auth_time"] is False + + # Nominal case + body = { + "require_auth_time": True, + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["require_auth_time"] is True + + # Error case + body = { + "require_auth_time": "invalid", + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" + + +def test_redirect_uri(test_client): + """RFC6749 indicate that fragments are forbidden in redirect_uri. + + The redirection endpoint URI MUST be an absolute URI as defined by + [RFC3986] Section 4.3. [...] The endpoint URI MUST NOT include a + fragment component. + + https://www.rfc-editor.org/rfc/rfc6749#section-3.1.2 + """ + # Nominal case + body = { + "redirect_uris": ["https://client.test"], + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert "client_id" in resp + assert resp["client_name"] == "Authlib" + assert resp["redirect_uris"] == ["https://client.test"] + + # Error case + body = { + "redirect_uris": ["https://client.test#fragment"], + "client_name": "Authlib", + } + rv = test_client.post( + "/create_client", json=body, headers={"Authorization": "bearer abc"} + ) + resp = json.loads(rv.data) + assert resp["error"] in "invalid_client_metadata" diff --git a/tests/flask/test_oauth2/test_code_challenge.py b/tests/flask/test_oauth2/test_code_challenge.py index f3c257958..77014b0fc 100644 --- a/tests/flask/test_oauth2/test_code_challenge.py +++ b/tests/flask/test_oauth2/test_code_challenge.py @@ -1,225 +1,312 @@ +import pytest from flask import json + from authlib.common.security import generate_token -from authlib.common.urls import urlparse, url_decode +from authlib.common.urls import url_decode +from authlib.common.urls import urlparse from authlib.oauth2.rfc6749 import grants -from authlib.oauth2.rfc7636 import ( - CodeChallenge as _CodeChallenge, - create_s256_code_challenge, -) -from .models import db, User, Client -from .models import CodeGrantMixin, save_authorization_code -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server +from authlib.oauth2.rfc7636 import CodeChallenge as _CodeChallenge +from authlib.oauth2.rfc7636 import create_s256_code_challenge + +from .models import CodeGrantMixin +from .models import save_authorization_code +from .oauth2_server import create_basic_header + +authorize_url = "/oauth/authorize?response_type=code&client_id=client-id" class AuthorizationCodeGrant(CodeGrantMixin, grants.AuthorizationCodeGrant): - TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none'] + TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"] def save_authorization_code(self, code, request): return save_authorization_code(code, request) class CodeChallenge(_CodeChallenge): - SUPPORTED_CODE_CHALLENGE_METHOD = ['plain', 'S256', 'S128'] - - -class CodeChallengeTest(TestCase): - def prepare_data(self, token_endpoint_auth_method='none'): - server = create_authorization_server(self.app) - server.register_grant( - AuthorizationCodeGrant, - [CodeChallenge(required=True)] - ) - - user = User(username='foo') - db.session.add(user) - db.session.commit() - - client_secret = '' - if token_endpoint_auth_method != 'none': - client_secret = 'code-secret' - - client = Client( - user_id=user.id, - client_id='code-client', - client_secret=client_secret, - ) - client.set_client_metadata({ - 'redirect_uris': ['https://a.b'], - 'scope': 'profile address', - 'token_endpoint_auth_method': token_endpoint_auth_method, - 'response_types': ['code'], - 'grant_types': ['authorization_code'], - }) - self.authorize_url = ( - '/oauth/authorize?response_type=code' - '&client_id=code-client' - ) - db.session.add(client) - db.session.commit() - - def test_missing_code_challenge(self): - self.prepare_data() - rv = self.client.get(self.authorize_url + '&code_challenge_method=plain') - self.assertIn(b'Missing', rv.data) - - def test_has_code_challenge(self): - self.prepare_data() - rv = self.client.get(self.authorize_url + '&code_challenge=abc') - self.assertEqual(rv.data, b'ok') - - def test_invalid_code_challenge_method(self): - self.prepare_data() - suffix = '&code_challenge=abc&code_challenge_method=invalid' - rv = self.client.get(self.authorize_url + suffix) - self.assertIn(b'Unsupported', rv.data) - - def test_supported_code_challenge_method(self): - self.prepare_data() - suffix = '&code_challenge=abc&code_challenge_method=plain' - rv = self.client.get(self.authorize_url + suffix) - self.assertEqual(rv.data, b'ok') - - def test_trusted_client_without_code_challenge(self): - self.prepare_data('client_secret_basic') - rv = self.client.get(self.authorize_url) - self.assertEqual(rv.data, b'ok') - - rv = self.client.post(self.authorize_url, data={'user_id': '1'}) - self.assertIn('code=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - - code = params['code'] - headers = self.create_basic_header('code-client', 'code-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': code, - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - - def test_missing_code_verifier(self): - self.prepare_data() - url = self.authorize_url + '&code_challenge=foo' - rv = self.client.post(url, data={'user_id': '1'}) - self.assertIn('code=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - code = params['code'] - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': code, - 'client_id': 'code-client', - }) - resp = json.loads(rv.data) - self.assertIn('Missing', resp['error_description']) - - def test_trusted_client_missing_code_verifier(self): - self.prepare_data('client_secret_basic') - url = self.authorize_url + '&code_challenge=foo' - rv = self.client.post(url, data={'user_id': '1'}) - self.assertIn('code=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - code = params['code'] - headers = self.create_basic_header('code-client', 'code-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': code, - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('Missing', resp['error_description']) - - def test_plain_code_challenge_invalid(self): - self.prepare_data() - url = self.authorize_url + '&code_challenge=foo' - rv = self.client.post(url, data={'user_id': '1'}) - self.assertIn('code=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - code = params['code'] - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': code, - 'code_verifier': 'bar', - 'client_id': 'code-client', - }) - resp = json.loads(rv.data) - self.assertIn('Invalid', resp['error_description']) - - def test_plain_code_challenge_failed(self): - self.prepare_data() - url = self.authorize_url + '&code_challenge=foo' - rv = self.client.post(url, data={'user_id': '1'}) - self.assertIn('code=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - code = params['code'] - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': code, - 'code_verifier': generate_token(48), - 'client_id': 'code-client', - }) - resp = json.loads(rv.data) - self.assertIn('failed', resp['error_description']) - - def test_plain_code_challenge_success(self): - self.prepare_data() - code_verifier = generate_token(48) - url = self.authorize_url + '&code_challenge=' + code_verifier - rv = self.client.post(url, data={'user_id': '1'}) - self.assertIn('code=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - code = params['code'] - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': code, - 'code_verifier': code_verifier, - 'client_id': 'code-client', - }) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - - def test_s256_code_challenge_success(self): - self.prepare_data() - code_verifier = generate_token(48) - code_challenge = create_s256_code_challenge(code_verifier) - url = self.authorize_url + '&code_challenge=' + code_challenge - url += '&code_challenge_method=S256' - - rv = self.client.post(url, data={'user_id': '1'}) - self.assertIn('code=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - code = params['code'] - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': code, - 'code_verifier': code_verifier, - 'client_id': 'code-client', - }) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - - def test_not_implemented_code_challenge_method(self): - self.prepare_data() - url = self.authorize_url + '&code_challenge=foo' - url += '&code_challenge_method=S128' - - rv = self.client.post(url, data={'user_id': '1'}) - self.assertIn('code=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - code = params['code'] - self.assertRaises( - RuntimeError, self.client.post, '/oauth/token', + SUPPORTED_CODE_CHALLENGE_METHOD = ["plain", "S256", "S128"] + + +@pytest.fixture(autouse=True) +def server(server): + server.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=True)]) + return server + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "profile address", + "token_endpoint_auth_method": "none", + "response_types": ["code"], + "grant_types": ["authorization_code"], + } + ) + db.session.add(client) + db.session.commit() + return client + + +def test_missing_code_challenge(test_client): + rv = test_client.get(authorize_url + "&code_challenge_method=plain") + assert "Missing" in rv.location + + +def test_has_code_challenge(test_client): + rv = test_client.get( + authorize_url + "&code_challenge=Zhs2POMonIVVHZteWfoU7cSXQSm0YjghikFGJSDI2_s" + ) + assert rv.data == b"ok" + + +def test_invalid_code_challenge(test_client): + rv = test_client.get( + authorize_url + "&code_challenge=abc&code_challenge_method=plain" + ) + assert "Invalid" in rv.location + + +def test_invalid_code_challenge_method(test_client): + suffix = "&code_challenge=Zhs2POMonIVVHZteWfoU7cSXQSm0YjghikFGJSDI2_s&code_challenge_method=invalid" + rv = test_client.get(authorize_url + suffix) + assert "Unsupported" in rv.location + + +def test_supported_code_challenge_method(test_client): + suffix = "&code_challenge=Zhs2POMonIVVHZteWfoU7cSXQSm0YjghikFGJSDI2_s&code_challenge_method=plain" + rv = test_client.get(authorize_url + suffix) + assert rv.data == b"ok" + + +def test_trusted_client_without_code_challenge(test_client, db, client): + client.client_secret = "client-secret" + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "profile address", + "token_endpoint_auth_method": "client_secret_basic", + "response_types": ["code"], + "grant_types": ["authorization_code"], + } + ) + db.session.add(client) + db.session.commit() + + rv = test_client.get(authorize_url) + assert rv.data == b"ok" + + rv = test_client.post(authorize_url, data={"user_id": "1"}) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + + code = params["code"] + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + + +def test_code_verifier_without_code_challenge(test_client, db, client): + """RFC 9700 Section 4.8.2: the authorization server MUST ensure that + if there was no code_challenge in the authorization request, a + request to the token endpoint containing a code_verifier is rejected.""" + client.client_secret = "client-secret" + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "profile address", + "token_endpoint_auth_method": "client_secret_basic", + "response_types": ["code"], + "grant_types": ["authorization_code"], + } + ) + db.session.add(client) + db.session.commit() + + rv = test_client.get(authorize_url) + assert rv.data == b"ok" + + rv = test_client.post(authorize_url, data={"user_id": "1"}) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + code = params["code"] + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + "code_verifier": generate_token(48), + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_request" + + +def test_missing_code_verifier(test_client): + url = authorize_url + "&code_challenge=Zhs2POMonIVVHZteWfoU7cSXQSm0YjghikFGJSDI2_s" + rv = test_client.post(url, data={"user_id": "1"}) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + code = params["code"] + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + "client_id": "client-id", + }, + ) + resp = json.loads(rv.data) + assert "Missing" in resp["error_description"] + + +def test_trusted_client_missing_code_verifier(test_client, db, client): + client.client_secret = "client-secret" + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "profile address", + "token_endpoint_auth_method": "client_secret_basic", + "response_types": ["code"], + "grant_types": ["authorization_code"], + } + ) + db.session.add(client) + db.session.commit() + + url = authorize_url + "&code_challenge=Zhs2POMonIVVHZteWfoU7cSXQSm0YjghikFGJSDI2_s" + rv = test_client.post(url, data={"user_id": "1"}) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + code = params["code"] + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "Missing" in resp["error_description"] + + +def test_plain_code_challenge_invalid(test_client): + url = authorize_url + "&code_challenge=Zhs2POMonIVVHZteWfoU7cSXQSm0YjghikFGJSDI2_s" + rv = test_client.post(url, data={"user_id": "1"}) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + code = params["code"] + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + "code_verifier": "bar", + "client_id": "client-id", + }, + ) + resp = json.loads(rv.data) + assert "Invalid" in resp["error_description"] + + +def test_plain_code_challenge_failed(test_client): + url = authorize_url + "&code_challenge=Zhs2POMonIVVHZteWfoU7cSXQSm0YjghikFGJSDI2_s" + rv = test_client.post(url, data={"user_id": "1"}) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + code = params["code"] + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + "code_verifier": generate_token(48), + "client_id": "client-id", + }, + ) + resp = json.loads(rv.data) + assert "failed" in resp["error_description"] + + +def test_plain_code_challenge_success(test_client): + code_verifier = generate_token(48) + url = authorize_url + "&code_challenge=" + code_verifier + rv = test_client.post(url, data={"user_id": "1"}) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + code = params["code"] + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + "code_verifier": code_verifier, + "client_id": "client-id", + }, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + + +def test_s256_code_challenge_success(test_client): + code_verifier = generate_token(48) + code_challenge = create_s256_code_challenge(code_verifier) + url = authorize_url + "&code_challenge=" + code_challenge + url += "&code_challenge_method=S256" + + rv = test_client.post(url, data={"user_id": "1"}) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + code = params["code"] + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + "code_verifier": code_verifier, + "client_id": "client-id", + }, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + + +def test_not_implemented_code_challenge_method(test_client): + url = authorize_url + "&code_challenge=Zhs2POMonIVVHZteWfoU7cSXQSm0YjghikFGJSDI2_s" + url += "&code_challenge_method=S128" + + rv = test_client.post(url, data={"user_id": "1"}) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + code = params["code"] + with pytest.raises(RuntimeError): + test_client.post( + "/oauth/token", data={ - 'grant_type': 'authorization_code', - 'code': code, - 'code_verifier': generate_token(48), - 'client_id': 'code-client', - } + "grant_type": "authorization_code", + "code": code, + "code_verifier": generate_token(48), + "client_id": "client-id", + }, ) diff --git a/tests/flask/test_oauth2/test_device_code_grant.py b/tests/flask/test_oauth2/test_device_code_grant.py index 6f135db33..d69d6f90a 100644 --- a/tests/flask/test_oauth2/test_device_code_grant.py +++ b/tests/flask/test_oauth2/test_device_code_grant.py @@ -1,48 +1,52 @@ import time + +import pytest from flask import json + from authlib.oauth2.rfc8628 import ( DeviceAuthorizationEndpoint as _DeviceAuthorizationEndpoint, - DeviceCodeGrant as _DeviceCodeGrant, - DeviceCredentialDict, ) -from .models import db, User, Client -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server +from authlib.oauth2.rfc8628 import DeviceCodeGrant as _DeviceCodeGrant +from authlib.oauth2.rfc8628 import DeviceCredentialDict +from .models import Client +from .models import User +from .models import db device_credentials = { - 'valid-device': { - 'client_id': 'client', - 'expires_in': 1800, - 'user_code': 'code', + "valid-device": { + "client_id": "client-id", + "expires_in": 1800, + "user_code": "code", }, - 'expired-token': { - 'client_id': 'client', - 'expires_in': -100, - 'user_code': 'none', + "expired-token": { + "client_id": "client-id", + "expires_in": -100, + "user_code": "none", }, - 'invalid-client': { - 'client_id': 'invalid', - 'expires_in': 1800, - 'user_code': 'none', + "invalid-client": { + "client_id": "invalid", + "expires_in": 1800, + "user_code": "none", }, - 'denied-code': { - 'client_id': 'client', - 'expires_in': 1800, - 'user_code': 'denied', + "denied-code": { + "client_id": "client-id", + "expires_in": 1800, + "user_code": "denied", }, - 'grant-code': { - 'client_id': 'client', - 'expires_in': 1800, - 'user_code': 'code', + "grant-code": { + "client_id": "client-id", + "expires_in": 1800, + "user_code": "code", + }, + "pending-code": { + "client_id": "client-id", + "expires_in": 1800, + "user_code": "none", }, - 'pending-code': { - 'client_id': 'client', - 'expires_in': 1800, - 'user_code': 'none', - } } + class DeviceCodeGrant(_DeviceCodeGrant): def query_device_credential(self, device_code): data = device_credentials.get(device_code) @@ -50,188 +54,211 @@ def query_device_credential(self, device_code): return None now = int(time.time()) - data['expires_at'] = now + data['expires_in'] - data['device_code'] = device_code - data['scope'] = 'profile' - data['interval'] = 5 - data['verification_uri'] = 'https://example.com/activate' + data["expires_at"] = now + data["expires_in"] + data["device_code"] = device_code + data["scope"] = "profile" + data["interval"] = 5 + data["verification_uri"] = "https://resource.test/activate" return DeviceCredentialDict(data) def query_user_grant(self, user_code): - if user_code == 'code': - return User.query.get(1), True - if user_code == 'denied': - return User.query.get(1), False + if user_code == "code": + return db.session.get(User, 1), True + if user_code == "denied": + return db.session.get(User, 1), False return None - def should_slow_down(self, credential, now): + def should_slow_down(self, credential): return False -class DeviceCodeGrantTest(TestCase): - def create_server(self): - server = create_authorization_server(self.app) - server.register_grant(DeviceCodeGrant) - self.server = server - return server - - def prepare_data(self, grant_type=DeviceCodeGrant.GRANT_TYPE): - user = User(username='foo') - db.session.add(user) - db.session.commit() - client = Client( - user_id=user.id, - client_id='client', - client_secret='secret', - ) - client.set_client_metadata({ - 'redirect_uris': ['http://localhost/authorized'], - 'scope': 'profile', - 'grant_types': [grant_type], - }) - db.session.add(client) - db.session.commit() - - def test_invalid_request(self): - self.create_server() - self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'valid-device', - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'missing', - 'client_id': 'client', - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - def test_unauthorized_client(self): - self.create_server() - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'valid-device', - 'client_id': 'invalid', - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'unauthorized_client') - - self.prepare_data(grant_type='password') - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'valid-device', - 'client_id': 'client', - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'unauthorized_client') - - def test_invalid_client(self): - self.create_server() - self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'invalid-client', - 'client_id': 'invalid', - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - def test_expired_token(self): - self.create_server() - self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'expired-token', - 'client_id': 'client', - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'expired_token') - - def test_denied_by_user(self): - self.create_server() - self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'denied-code', - 'client_id': 'client', - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'access_denied') - - def test_authorization_pending(self): - self.create_server() - self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'pending-code', - 'client_id': 'client', - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'authorization_pending') - - def test_get_access_token(self): - self.create_server() - self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'grant-code', - 'client_id': 'client', - }) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - - class DeviceAuthorizationEndpoint(_DeviceAuthorizationEndpoint): def get_verification_uri(self): - return 'https://example.com/activate' + return "https://resource.test/activate" def save_device_credential(self, client_id, scope, data): pass -class DeviceAuthorizationEndpointTest(TestCase): - def create_server(self): - server = create_authorization_server(self.app) - server.register_endpoint(DeviceAuthorizationEndpoint) - self.server = server - - @self.app.route('/device_authorize', methods=['POST']) - def device_authorize(): - name = DeviceAuthorizationEndpoint.ENDPOINT_NAME - return server.create_endpoint_response(name) - - return server - - def test_missing_client_id(self): - self.create_server() - rv = self.client.post('/device_authorize', data={ - 'scope': 'profile' - }) - self.assertEqual(rv.status_code, 400) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - def test_create_authorization_response(self): - self.create_server() - rv = self.client.post('/device_authorize', data={ - 'client_id': 'client', - }) - self.assertEqual(rv.status_code, 200) - resp = json.loads(rv.data) - self.assertIn('device_code', resp) - self.assertIn('user_code', resp) - self.assertEqual(resp['verification_uri'], 'https://example.com/activate') - self.assertEqual( - resp['verification_uri_complete'], - 'https://example.com/activate?user_code=' + resp['user_code'] - ) +@pytest.fixture(autouse=True) +def server(server, app): + server.register_grant(DeviceCodeGrant) + + @app.route("/device_authorize", methods=["POST"]) + def device_authorize(): + name = DeviceAuthorizationEndpoint.ENDPOINT_NAME + return server.create_endpoint_response(name) + + server.register_endpoint(DeviceAuthorizationEndpoint) + + return server + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "redirect_uris": ["https://client.test/authorized"], + "scope": "profile", + "grant_types": [DeviceCodeGrant.GRANT_TYPE], + "token_endpoint_auth_method": "none", + } + ) + db.session.add(client) + db.session.commit() + return client + + +def test_invalid_request(test_client): + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "client_id": "test", + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_request" + + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "missing", + "client_id": "client-id", + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_request" + + +def test_unauthorized_client(test_client, db, client): + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "valid-device", + "client_id": "invalid", + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + client.set_client_metadata( + { + "redirect_uris": ["https://client.test/authorized"], + "scope": "profile", + "grant_types": ["password"], + "token_endpoint_auth_method": "none", + } + ) + db.session.add(client) + db.session.commit() + + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "valid-device", + "client_id": "client-id", + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "unauthorized_client" + + +def test_invalid_client(test_client): + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "invalid-client", + "client_id": "invalid", + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + +def test_expired_token(test_client): + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "expired-token", + "client_id": "client-id", + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "expired_token" + + +def test_denied_by_user(test_client): + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "denied-code", + "client_id": "client-id", + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "access_denied" + + +def test_authorization_pending(test_client): + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "pending-code", + "client_id": "client-id", + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "authorization_pending" + + +def test_get_access_token(test_client): + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "grant-code", + "client_id": "client-id", + }, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + + +def test_missing_client_id(test_client): + rv = test_client.post("/device_authorize", data={"scope": "profile"}) + assert rv.status_code == 401 + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + +def test_create_authorization_response(test_client): + client = Client( + user_id=1, + client_id="client", + client_secret="secret", + ) + db.session.add(client) + db.session.commit() + rv = test_client.post( + "/device_authorize", + data={ + "client_id": "client-id", + }, + ) + assert rv.status_code == 200 + resp = json.loads(rv.data) + assert "device_code" in resp + assert "user_code" in resp + assert resp["verification_uri"] == "https://resource.test/activate" + assert ( + resp["verification_uri_complete"] + == "https://resource.test/activate?user_code=" + resp["user_code"] + ) diff --git a/tests/flask/test_oauth2/test_end_session.py b/tests/flask/test_oauth2/test_end_session.py new file mode 100644 index 000000000..542331e00 --- /dev/null +++ b/tests/flask/test_oauth2/test_end_session.py @@ -0,0 +1,502 @@ +"""Tests for RP-Initiated Logout endpoint.""" + +import pytest +from flask import request +from joserfc import jwt +from joserfc.jwk import KeySet + +from authlib.oauth2.rfc6749.errors import OAuth2Error +from authlib.oidc.rpinitiated import EndSessionEndpoint +from authlib.oidc.rpinitiated import EndSessionRequest +from tests.util import read_file_path + +from .models import Client + + +def create_id_token(claims): + """Create a signed ID token for testing.""" + header = {"alg": "RS256"} + jwks = read_file_path("jwks_private.json") + key = KeySet.import_key_set(jwks) + return jwt.encode(header, claims, key) + + +class MyEndSessionEndpoint(EndSessionEndpoint): + """Test endpoint implementation.""" + + def get_server_jwks(self): + return read_file_path("jwks_public.json") + + def is_post_logout_redirect_uri_legitimate( + self, request, post_logout_redirect_uri, client, logout_hint + ): + return True + + def end_session(self, end_session_request): + pass + + +class DefaultLegitimacyEndpoint(MyEndSessionEndpoint): + """Endpoint that uses default is_post_logout_redirect_uri_legitimate.""" + + def is_post_logout_redirect_uri_legitimate( + self, request, post_logout_redirect_uri, client, logout_hint + ): + # Call parent's default implementation + return EndSessionEndpoint.is_post_logout_redirect_uri_legitimate( + self, request, post_logout_redirect_uri, client, logout_hint + ) + + +class ErrorRaisingEndpoint(MyEndSessionEndpoint): + """Endpoint that raises error in end_session.""" + + def end_session(self, end_session_request): + from authlib.oauth2.rfc6749.errors import InvalidRequestError + + raise InvalidRequestError("Session termination failed") + + +@pytest.fixture +def endpoint_server(server, app): + """Server with EndSessionEndpoint registered.""" + endpoint = MyEndSessionEndpoint() + server.register_endpoint(endpoint) + + @app.route("/logout", methods=["GET", "POST"]) + def logout(): + return server.create_endpoint_response("end_session") or "Logged out" + + @app.route("/logout_interactive", methods=["GET", "POST"]) + def logout_interactive(): + try: + req = server.validate_endpoint_request("end_session") + except OAuth2Error as error: + return server.handle_error_response(None, error) + + if req.needs_confirmation and request.method == "GET": + return "Confirm logout", 200 + + return server.create_endpoint_response("end_session", req) or "Logged out" + + return server + + +@pytest.fixture +def client_model(user, db): + """Create a test client.""" + client = Client( + user_id=user.id, + client_id="client-id", + client_secret="client-secret", + ) + client.set_client_metadata( + { + "redirect_uris": ["https://client.test/callback"], + "post_logout_redirect_uris": [ + "https://client.test/logout", + "https://client.test/logged-out", + ], + } + ) + db.session.add(client) + db.session.commit() + yield client + db.session.delete(client) + + +@pytest.fixture +def valid_id_token(): + """Create a valid ID token.""" + return create_id_token( + { + "iss": "https://provider.test", + "sub": "user-1", + "aud": "client-id", + "exp": 9999999999, + "iat": 1000000000, + } + ) + + +# EndSessionRequest tests + + +def test_needs_confirmation_without_id_token(): + """needs_confirmation is True when id_token_claims is None.""" + req = EndSessionRequest(request=None, client=None, id_token_claims=None) + assert req.needs_confirmation is True + + +def test_needs_confirmation_with_id_token(): + """needs_confirmation is False when id_token_claims is present.""" + req = EndSessionRequest( + request=None, client=None, id_token_claims={"sub": "user-1"} + ) + assert req.needs_confirmation is False + + +# Non-interactive mode tests + + +def test_logout_with_valid_id_token( + test_client, endpoint_server, client_model, valid_id_token +): + """Logout with valid id_token_hint succeeds.""" + rv = test_client.get(f"/logout?id_token_hint={valid_id_token}") + + assert rv.status_code == 200 + assert rv.data == b"Logged out" + + +def test_logout_with_redirect_uri( + test_client, endpoint_server, client_model, valid_id_token +): + """Logout with valid redirect URI redirects.""" + rv = test_client.get( + f"/logout?id_token_hint={valid_id_token}" + "&post_logout_redirect_uri=https://client.test/logout" + ) + + assert rv.status_code == 302 + assert rv.headers["Location"] == "https://client.test/logout" + + +def test_logout_with_redirect_uri_and_state( + test_client, endpoint_server, client_model, valid_id_token +): + """State parameter is appended to redirect URI.""" + rv = test_client.get( + f"/logout?id_token_hint={valid_id_token}" + "&post_logout_redirect_uri=https://client.test/logout" + "&state=xyz123" + ) + + assert rv.status_code == 302 + assert rv.headers["Location"] == "https://client.test/logout?state=xyz123" + + +def test_logout_without_id_token(test_client, endpoint_server, client_model): + """Logout without id_token_hint succeeds in non-interactive mode.""" + rv = test_client.get("/logout") + + assert rv.status_code == 200 + assert rv.data == b"Logged out" + + +def test_invalid_redirect_uri_ignored( + test_client, endpoint_server, client_model, valid_id_token +): + """Unregistered redirect URI results in no redirect.""" + rv = test_client.get( + f"/logout?id_token_hint={valid_id_token}" + "&post_logout_redirect_uri=https://attacker.test/logout" + ) + + assert rv.status_code == 200 + assert rv.data == b"Logged out" + + +def test_post_with_form_data( + test_client, endpoint_server, client_model, valid_id_token +): + """POST with form-encoded data works.""" + rv = test_client.post( + "/logout", + data={ + "id_token_hint": valid_id_token, + "post_logout_redirect_uri": "https://client.test/logout", + "state": "abc", + }, + ) + + assert rv.status_code == 302 + assert rv.headers["Location"] == "https://client.test/logout?state=abc" + + +# Interactive mode tests + + +def test_confirmation_shown_without_id_token( + test_client, endpoint_server, client_model +): + """Without id_token_hint, confirmation page is shown on GET.""" + rv = test_client.get("/logout_interactive") + + assert rv.status_code == 200 + assert rv.data == b"Confirm logout" + + +def test_confirmation_bypassed_with_id_token( + test_client, endpoint_server, client_model, valid_id_token +): + """With valid id_token_hint, no confirmation needed.""" + rv = test_client.get(f"/logout_interactive?id_token_hint={valid_id_token}") + + assert rv.status_code == 200 + assert rv.data == b"Logged out" + + +def test_post_executes_logout(test_client, endpoint_server, client_model): + """POST request executes logout even without id_token_hint.""" + rv = test_client.post("/logout_interactive") + + assert rv.status_code == 200 + assert rv.data == b"Logged out" + + +def test_redirect_preserved_after_confirmation( + test_client, endpoint_server, client_model +): + """Redirect URI is used after POST confirmation.""" + rv = test_client.post( + "/logout_interactive", + data={ + "client_id": "client-id", + "post_logout_redirect_uri": "https://client.test/logout", + }, + ) + + assert rv.status_code == 302 + assert rv.headers["Location"] == "https://client.test/logout" + + +# Validation tests + + +def test_client_id_mismatch_error( + test_client, endpoint_server, client_model, valid_id_token +): + """client_id not matching aud claim returns error.""" + rv = test_client.get( + f"/logout?id_token_hint={valid_id_token}&client_id=other-client" + ) + + assert rv.status_code == 400 + assert rv.json["error"] == "invalid_request" + assert "'client_id' does not match 'aud' claim" in rv.json["error_description"] + + +def test_invalid_jwt_error(test_client, endpoint_server, client_model): + """Invalid JWT returns error.""" + rv = test_client.get("/logout?id_token_hint=invalid.jwt.token") + + assert rv.status_code == 400 + assert rv.json["error"] == "invalid_request" + + +def test_client_id_matches_aud_list(test_client, endpoint_server, client_model): + """client_id matches when aud is a list containing it.""" + id_token = create_id_token( + { + "iss": "https://provider.test", + "sub": "user-1", + "aud": ["client-id", "other-client"], + "exp": 9999999999, + "iat": 1000000000, + } + ) + rv = test_client.get(f"/logout?id_token_hint={id_token}&client_id=client-id") + + assert rv.status_code == 200 + + +def test_client_id_not_in_aud_list_error(test_client, endpoint_server, client_model): + """client_id not in aud list returns error.""" + id_token = create_id_token( + { + "iss": "https://provider.test", + "sub": "user-1", + "aud": ["other-client-1", "other-client-2"], + "exp": 9999999999, + "iat": 1000000000, + } + ) + rv = test_client.get(f"/logout?id_token_hint={id_token}&client_id=client-id") + + assert rv.status_code == 400 + assert rv.json["error"] == "invalid_request" + + +# Token expiration tests + + +def test_expired_id_token_accepted(test_client, endpoint_server, client_model): + """Expired ID tokens are accepted per the specification.""" + expired_token = create_id_token( + { + "iss": "https://provider.test", + "sub": "user-1", + "aud": "client-id", + "exp": 1, # Expired in 1970 + "iat": 0, + } + ) + rv = test_client.get(f"/logout?id_token_hint={expired_token}") + + assert rv.status_code == 200 + assert rv.data == b"Logged out" + + +def test_token_with_future_nbf_rejected(test_client, endpoint_server, client_model): + """Token with nbf in the future is rejected.""" + token = create_id_token( + { + "iss": "https://provider.test", + "sub": "user-1", + "aud": "client-id", + "exp": 9999999999, + "iat": 0, + "nbf": 9999999999, # Not valid until far future + } + ) + rv = test_client.get(f"/logout?id_token_hint={token}") + + assert rv.status_code == 400 + assert rv.json["error"] == "invalid_request" + + +# Client resolution tests + + +def test_client_resolved_from_single_aud( + test_client, endpoint_server, client_model, valid_id_token +): + """Client is resolved from single aud claim.""" + rv = test_client.get( + f"/logout?id_token_hint={valid_id_token}" + "&post_logout_redirect_uri=https://client.test/logout" + ) + + assert rv.status_code == 302 + assert rv.headers["Location"] == "https://client.test/logout" + + +def test_client_not_resolved_from_aud_list(test_client, endpoint_server, client_model): + """Client is not resolved from aud list (ambiguous).""" + id_token = create_id_token( + { + "iss": "https://provider.test", + "sub": "user-1", + "aud": ["client-id", "other-client"], + "exp": 9999999999, + "iat": 1000000000, + } + ) + rv = test_client.get( + f"/logout?id_token_hint={id_token}" + "&post_logout_redirect_uri=https://client.test/logout" + ) + + assert rv.status_code == 200 + assert rv.data == b"Logged out" # No redirect + + +def test_client_resolved_with_explicit_client_id( + test_client, endpoint_server, client_model +): + """Client is resolved when client_id is provided explicitly.""" + id_token = create_id_token( + { + "iss": "https://provider.test", + "sub": "user-1", + "aud": ["client-id", "other-client"], + "exp": 9999999999, + "iat": 1000000000, + } + ) + rv = test_client.get( + f"/logout?id_token_hint={id_token}" + "&client_id=client-id" + "&post_logout_redirect_uri=https://client.test/logout" + ) + + assert rv.status_code == 302 + assert rv.headers["Location"] == "https://client.test/logout" + + +def test_redirect_requires_client(test_client, endpoint_server, client_model): + """Redirect URI without resolvable client is ignored.""" + rv = test_client.get("/logout?post_logout_redirect_uri=https://client.test/logout") + + assert rv.status_code == 200 + assert rv.data == b"Logged out" # No redirect + + +def test_interactive_mode_error_handling(test_client, endpoint_server, client_model): + """Error during validation returns error response in interactive mode.""" + rv = test_client.get("/logout_interactive?id_token_hint=invalid.jwt.token") + + assert rv.status_code == 400 + assert rv.json["error"] == "invalid_request" + + +def test_validate_unknown_endpoint(server): + """validate_endpoint_request with unknown endpoint raises RuntimeError.""" + with pytest.raises(RuntimeError, match="There is no 'unknown' endpoint"): + server.validate_endpoint_request("unknown") + + +def test_create_endpoint_response_unknown_endpoint(server): + """create_endpoint_response with unknown endpoint raises RuntimeError.""" + with pytest.raises(RuntimeError, match="There is no 'unknown' endpoint"): + server.create_endpoint_response("unknown") + + +def test_default_is_post_logout_redirect_uri_legitimate( + server, app, test_client, client_model, valid_id_token +): + """Default is_post_logout_redirect_uri_legitimate returns False.""" + endpoint = DefaultLegitimacyEndpoint() + server.register_endpoint(endpoint) + + @app.route("/logout_default", methods=["GET", "POST"]) + def logout_default(): + return server.create_endpoint_response("end_session") or "Logged out" + + # Without id_token_hint, redirect should be ignored (default returns False) + rv = test_client.get( + "/logout_default?" + "client_id=client-id" + "&post_logout_redirect_uri=https://client.test/logout" + ) + assert rv.status_code == 200 + assert rv.data == b"Logged out" # No redirect + + +def test_create_endpoint_response_with_validated_request_error( + server, app, test_client, client_model, valid_id_token +): + """Error in create_response with validated request returns error response.""" + endpoint = ErrorRaisingEndpoint() + server.register_endpoint(endpoint) + + @app.route("/logout_error", methods=["GET", "POST"]) + def logout_error(): + req = server.validate_endpoint_request("end_session") + return server.create_endpoint_response("end_session", req) + + rv = test_client.get(f"/logout_error?id_token_hint={valid_id_token}") + assert rv.status_code == 400 + assert rv.json["error"] == "invalid_request" + assert "Session termination failed" in rv.json["error_description"] + + +def test_ui_locales_extracted(server, app, test_client, client_model, valid_id_token): + """ui_locales parameter is extracted and available in EndSessionRequest.""" + endpoint = MyEndSessionEndpoint() + server.register_endpoint(endpoint) + + captured = {} + + @app.route("/logout_locales", methods=["GET", "POST"]) + def logout_locales(): + req = server.validate_endpoint_request("end_session") + captured["ui_locales"] = req.ui_locales + return server.create_endpoint_response("end_session", req) or "Logged out" + + rv = test_client.get( + f"/logout_locales?id_token_hint={valid_id_token}&ui_locales=fr-FR%20en" + ) + assert rv.status_code == 200 + assert captured["ui_locales"] == "fr-FR en" diff --git a/tests/flask/test_oauth2/test_implicit_grant.py b/tests/flask/test_oauth2/test_implicit_grant.py index fa0ce7615..802f34792 100644 --- a/tests/flask/test_oauth2/test_implicit_grant.py +++ b/tests/flask/test_oauth2/test_implicit_grant.py @@ -1,82 +1,144 @@ +import pytest + from authlib.oauth2.rfc6749.grants import ImplicitGrant -from .models import db, User, Client -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server - - -class ImplicitTest(TestCase): - def prepare_data(self, is_confidential=False, response_type='token'): - server = create_authorization_server(self.app) - server.register_grant(ImplicitGrant) - self.server = server - - user = User(username='foo') - db.session.add(user) - db.session.commit() - if is_confidential: - client_secret = 'implicit-secret' - token_endpoint_auth_method = 'client_secret_basic' - else: - client_secret = '' - token_endpoint_auth_method = 'none' - - client = Client( - user_id=user.id, - client_id='implicit-client', - client_secret=client_secret, - ) - client.set_client_metadata({ - 'redirect_uris': ['http://localhost/authorized'], - 'scope': 'profile', - 'response_types': [response_type], - 'grant_types': ['implicit'], - 'token_endpoint_auth_method': token_endpoint_auth_method, - }) - self.authorize_url = ( - '/oauth/authorize?response_type=token' - '&client_id=implicit-client' - ) - db.session.add(client) - db.session.commit() - - def test_get_authorize(self): - self.prepare_data() - rv = self.client.get(self.authorize_url) - self.assertEqual(rv.data, b'ok') - - def test_confidential_client(self): - self.prepare_data(True) - rv = self.client.get(self.authorize_url) - self.assertIn(b'invalid_client', rv.data) - - def test_unsupported_client(self): - self.prepare_data(response_type='code') - rv = self.client.get(self.authorize_url) - self.assertIn(b'unauthorized_client', rv.data) - - def test_invalid_authorize(self): - self.prepare_data() - rv = self.client.post(self.authorize_url) - self.assertIn('#error=access_denied', rv.location) - - self.server.metadata = {'scopes_supported': ['profile']} - rv = self.client.post(self.authorize_url + '&scope=invalid') - self.assertIn('#error=invalid_scope', rv.location) - - def test_authorize_token(self): - self.prepare_data() - rv = self.client.post(self.authorize_url, data={'user_id': '1'}) - self.assertIn('access_token=', rv.location) - - url = self.authorize_url + '&state=bar&scope=profile' - rv = self.client.post(url, data={'user_id': '1'}) - self.assertIn('access_token=', rv.location) - self.assertIn('state=bar', rv.location) - self.assertIn('scope=profile', rv.location) - - def test_token_generator(self): - m = 'tests.flask.test_oauth2.oauth2_server:token_generator' - self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) - self.prepare_data() - rv = self.client.post(self.authorize_url, data={'user_id': '1'}) - self.assertIn('access_token=i-implicit.1.', rv.location) + +authorize_url = "/oauth/authorize?response_type=token&client_id=client-id" + + +@pytest.fixture(autouse=True) +def server(server): + server.register_grant(ImplicitGrant) + return server + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "redirect_uris": ["https://client.test/authorized"], + "scope": "profile", + "response_types": ["token"], + "grant_types": ["implicit"], + "token_endpoint_auth_method": "none", + } + ) + db.session.add(client) + db.session.commit() + return client + + +def test_get_authorize(test_client): + rv = test_client.get(authorize_url) + assert rv.data == b"ok" + + +def test_confidential_client(test_client, db, client): + client.client_secret = "client-secret" + client.set_client_metadata( + { + "redirect_uris": ["https://client.test/authorized"], + "scope": "profile", + "response_types": ["token"], + "grant_types": ["implicit"], + "token_endpoint_auth_method": "client_secret_basic", + } + ) + db.session.add(client) + db.session.commit() + + rv = test_client.get(authorize_url) + assert b"invalid_client" in rv.data + + +def test_unsupported_client(test_client, db, client): + client.set_client_metadata( + { + "redirect_uris": ["https://client.test/authorized"], + "scope": "profile", + "response_types": ["code"], + "grant_types": ["implicit"], + "token_endpoint_auth_method": "none", + } + ) + db.session.add(client) + db.session.commit() + rv = test_client.get(authorize_url) + assert "unauthorized_client" in rv.location + + +def test_invalid_authorize(test_client, server): + rv = test_client.post(authorize_url) + assert "#error=access_denied" in rv.location + + server.scopes_supported = ["profile"] + rv = test_client.post(authorize_url + "&scope=invalid") + assert "#error=invalid_scope" in rv.location + + +def test_authorize_token(test_client): + rv = test_client.post(authorize_url, data={"user_id": "1"}) + assert "access_token=" in rv.location + + url = authorize_url + "&state=bar&scope=profile" + rv = test_client.post(url, data={"user_id": "1"}) + assert "access_token=" in rv.location + assert "state=bar" in rv.location + assert "scope=profile" in rv.location + + +def test_token_generator(test_client, app, server): + m = "tests.flask.test_oauth2.oauth2_server:token_generator" + app.config.update({"OAUTH2_ACCESS_TOKEN_GENERATOR": m}) + server.load_config(app.config) + rv = test_client.post(authorize_url, data={"user_id": "1"}) + assert "access_token=c-implicit.1." in rv.location + + +def test_missing_scope_uses_default(test_client, client, monkeypatch): + """Per RFC 6749 Section 3.3, when scope is omitted, the server should use + a pre-defined default value from client.get_allowed_scope(). + """ + + def get_allowed_scope_with_default(scope): + if scope is None: + return "default_scope" + return scope + + monkeypatch.setattr(client, "get_allowed_scope", get_allowed_scope_with_default) + + rv = test_client.post(authorize_url, data={"user_id": "1"}) + assert "access_token=" in rv.location + assert "scope=default_scope" in rv.location + + +def test_missing_scope_empty_default(test_client, client, monkeypatch): + """When client.get_allowed_scope() returns empty string for missing scope, + the token should be issued without a scope. + """ + + def get_allowed_scope_empty(scope): + if scope is None: + return "" + return scope + + monkeypatch.setattr(client, "get_allowed_scope", get_allowed_scope_empty) + + rv = test_client.post(authorize_url, data={"user_id": "1"}) + assert "access_token=" in rv.location + assert "scope=" not in rv.location + + +def test_missing_scope_rejected(test_client, client, monkeypatch): + """Per RFC 6749 Section 3.3, when scope is omitted and client.get_allowed_scope() + returns None, the authorization should fail with invalid_scope. + """ + + def get_allowed_scope_reject(scope): + if scope is None: + return None + return scope + + monkeypatch.setattr(client, "get_allowed_scope", get_allowed_scope_reject) + + rv = test_client.post(authorize_url, data={"user_id": "1"}) + assert "#error=invalid_scope" in rv.location diff --git a/tests/flask/test_oauth2/test_introspection_endpoint.py b/tests/flask/test_oauth2/test_introspection_endpoint.py index 578a1c240..6ed1b9a7c 100644 --- a/tests/flask/test_oauth2/test_introspection_endpoint.py +++ b/tests/flask/test_oauth2/test_introspection_endpoint.py @@ -1,155 +1,157 @@ +import pytest from flask import json + from authlib.integrations.sqla_oauth2 import create_query_token_func from authlib.oauth2.rfc7662 import IntrospectionEndpoint -from .models import db, User, Client, Token -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server +from .models import Token +from .models import User +from .models import db +from .oauth2_server import create_basic_header query_token = create_query_token_func(db.session, Token) class MyIntrospectionEndpoint(IntrospectionEndpoint): - def query_token(self, token, token_type_hint, client): - return query_token(token, token_type_hint, client) + def check_permission(self, token, client, request): + return True + + def query_token(self, token, token_type_hint): + return query_token(token, token_type_hint) def introspect_token(self, token): - user = User.query.get(token.user_id) + user = db.session.get(User, token.user_id) return { - "active": not token.revoked, + "active": True, "client_id": token.client_id, "username": user.username, "scope": token.scope, "sub": user.get_user_id(), "aud": token.client_id, - "iss": "https://server.example.com/", - "exp": token.get_expires_at(), + "iss": "https://provider.test/", + "exp": token.issued_at + token.expires_in, "iat": token.issued_at, } -class IntrospectTokenTest(TestCase): - def prepare_data(self): - app = self.app - - server = create_authorization_server(app) - server.register_endpoint(MyIntrospectionEndpoint) - - @app.route('/oauth/introspect', methods=['POST']) - def introspect_token(): - return server.create_endpoint_response('introspection') - - user = User(username='foo') - db.session.add(user) - db.session.commit() - client = Client( - user_id=user.id, - client_id='introspect-client', - client_secret='introspect-secret', - ) - client.set_client_metadata({ - 'scope': 'profile', - 'redirect_uris': ['http://a.b/c'], - }) - db.session.add(client) - db.session.commit() - - def create_token(self): - token = Token( - user_id=1, - client_id='introspect-client', - token_type='bearer', - access_token='a1', - refresh_token='r1', - scope='profile', - expires_in=3600, - ) - db.session.add(token) - db.session.commit() - - def test_invalid_client(self): - self.prepare_data() - rv = self.client.post('/oauth/introspect') - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - headers = {'Authorization': 'invalid token_string'} - rv = self.client.post('/oauth/introspect', headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - headers = self.create_basic_header( - 'invalid-client', 'introspect-secret' - ) - rv = self.client.post('/oauth/introspect', headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - headers = self.create_basic_header( - 'introspect-client', 'invalid-secret' - ) - rv = self.client.post('/oauth/introspect', headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - def test_invalid_token(self): - self.prepare_data() - headers = self.create_basic_header( - 'introspect-client', 'introspect-secret' - ) - rv = self.client.post('/oauth/introspect', headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - rv = self.client.post('/oauth/introspect', data={ - 'token_type_hint': 'refresh_token', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - rv = self.client.post('/oauth/introspect', data={ - 'token': 'a1', - 'token_type_hint': 'unsupported_token_type', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'unsupported_token_type') - - rv = self.client.post('/oauth/introspect', data={ - 'token': 'invalid-token', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['active'], False) - - rv = self.client.post('/oauth/introspect', data={ - 'token': 'a1', - 'token_type_hint': 'refresh_token', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['active'], False) - - def test_introspect_token_with_hint(self): - self.prepare_data() - self.create_token() - headers = self.create_basic_header( - 'introspect-client', 'introspect-secret' - ) - rv = self.client.post('/oauth/introspect', data={ - 'token': 'a1', - 'token_type_hint': 'access_token', - }, headers=headers) - self.assertEqual(rv.status_code, 200) - resp = json.loads(rv.data) - self.assertEqual(resp['client_id'], 'introspect-client') - - def test_introspect_token_without_hint(self): - self.prepare_data() - self.create_token() - headers = self.create_basic_header( - 'introspect-client', 'introspect-secret' - ) - rv = self.client.post('/oauth/introspect', data={ - 'token': 'a1', - }, headers=headers) - self.assertEqual(rv.status_code, 200) - resp = json.loads(rv.data) - self.assertEqual(resp['client_id'], 'introspect-client') +@pytest.fixture(autouse=True) +def server(server, app): + server.register_endpoint(MyIntrospectionEndpoint) + + @app.route("/oauth/introspect", methods=["POST"]) + def introspect_token(): + return server.create_endpoint_response("introspection") + + return server + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["https://client.test/callback"], + } + ) + db.session.add(client) + db.session.commit() + return client + + +def test_invalid_client(test_client): + rv = test_client.post("/oauth/introspect") + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + headers = {"Authorization": "invalid token_string"} + rv = test_client.post("/oauth/introspect", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + headers = create_basic_header("invalid-client", "client-secret") + rv = test_client.post("/oauth/introspect", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + headers = create_basic_header("client-id", "invalid-secret") + rv = test_client.post("/oauth/introspect", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + +def test_invalid_token(test_client): + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post("/oauth/introspect", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_request" + + rv = test_client.post( + "/oauth/introspect", + data={ + "token_type_hint": "refresh_token", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_request" + + rv = test_client.post( + "/oauth/introspect", + data={ + "token": "a1", + "token_type_hint": "unsupported_token_type", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "unsupported_token_type" + + rv = test_client.post( + "/oauth/introspect", + data={ + "token": "invalid-token", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["active"] is False + + rv = test_client.post( + "/oauth/introspect", + data={ + "token": "a1", + "token_type_hint": "refresh_token", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["active"] is False + + +def test_introspect_token_with_hint(test_client, token): + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/introspect", + data={ + "token": "a1", + "token_type_hint": "access_token", + }, + headers=headers, + ) + assert rv.status_code == 200 + resp = json.loads(rv.data) + assert resp["client_id"] == "client-id" + + +def test_introspect_token_without_hint(test_client, token): + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/introspect", + data={ + "token": "a1", + }, + headers=headers, + ) + assert rv.status_code == 200 + resp = json.loads(rv.data) + assert resp["client_id"] == "client-id" diff --git a/tests/flask/test_oauth2/test_jwt_authorization_request.py b/tests/flask/test_oauth2/test_jwt_authorization_request.py new file mode 100644 index 000000000..142e24faa --- /dev/null +++ b/tests/flask/test_oauth2/test_jwt_authorization_request.py @@ -0,0 +1,495 @@ +import json + +import pytest +from joserfc import jwk +from joserfc import jwt + +from authlib.common.urls import add_params_to_uri +from authlib.oauth2 import rfc7591 +from authlib.oauth2 import rfc9101 +from authlib.oauth2.rfc6749.grants import ( + AuthorizationCodeGrant as _AuthorizationCodeGrant, +) +from tests.util import read_file_path + +from .models import Client +from .models import CodeGrantMixin +from .models import save_authorization_code + +authorize_url = "/oauth/authorize" + + +@pytest.fixture +def metadata(): + return {} + + +@pytest.fixture(autouse=True) +def server(server): + class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): + TOKEN_ENDPOINT_AUTH_METHODS = [ + "client_secret_basic", + "client_secret_post", + "none", + ] + + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + server.register_grant(AuthorizationCodeGrant) + return server + + +@pytest.fixture(autouse=True) +def client_registration_endpoint(app, server, metadata, db): + class ClientRegistrationEndpoint(rfc7591.ClientRegistrationEndpoint): + software_statement_alg_values_supported = ["RS256"] + + def authenticate_token(self, request): + auth_header = request.headers.get("Authorization") + request.user_id = 1 + return auth_header + + def resolve_public_key(self, request): + return read_file_path("rsa_public.pem") + + def save_client(self, client_info, client_metadata, request): + client = Client(user_id=request.user_id, **client_info) + client.set_client_metadata(client_metadata) + db.session.add(client) + db.session.commit() + return client + + def get_server_metadata(self): + return metadata + + server.register_endpoint( + ClientRegistrationEndpoint( + claims_classes=[ + rfc7591.ClientMetadataClaims, + rfc9101.ClientMetadataClaims, + ] + ) + ) + + @app.route("/create_client", methods=["POST"]) + def create_client(): + return server.create_endpoint_response(ClientRegistrationEndpoint.ENDPOINT_NAME) + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "profile address", + "token_endpoint_auth_method": "client_secret_basic", + "response_types": ["code"], + "grant_types": ["authorization_code"], + "jwks": read_file_path("jwks_public.json"), + "require_signed_request_object": False, + } + ) + db.session.add(client) + db.session.commit() + return client + + +def register_request_object_extension( + server, + metadata=None, + request_object=None, + support_request=True, + support_request_uri=True, +): + class JWTAuthenticationRequest(rfc9101.JWTAuthenticationRequest): + def resolve_client_public_key(self, client): + return read_file_path("jwk_public.json") + + def get_request_object(self, request_uri: str): + return request_object + + def get_server_metadata(self): + return metadata or {} + + def get_client_require_signed_request_object(self, client): + return client.client_metadata.get("require_signed_request_object", False) + + server.register_extension( + JWTAuthenticationRequest( + support_request=support_request, support_request_uri=support_request_uri + ) + ) + + +def test_request_parameter_get(test_client, server): + """Pass the authentication payload in a JWT in the request query parameter.""" + register_request_object_extension(server) + payload = {"response_type": "code", "client_id": "client-id"} + request_obj = jwt.encode( + {"alg": "RS256"}, payload, jwk.import_key(read_file_path("jwk_private.json")) + ) + url = add_params_to_uri( + authorize_url, {"client_id": "client-id", "request": request_obj} + ) + rv = test_client.get(url) + assert rv.data == b"ok" + + +def test_request_uri_parameter_get(test_client, server): + """Pass the authentication payload in a JWT in the request_uri query parameter.""" + payload = {"response_type": "code", "client_id": "client-id"} + request_obj = jwt.encode( + {"alg": "RS256"}, payload, jwk.import_key(read_file_path("jwk_private.json")) + ) + register_request_object_extension(server, request_object=request_obj) + + url = add_params_to_uri( + authorize_url, + { + "client_id": "client-id", + "request_uri": "https://client.test/request_object", + }, + ) + rv = test_client.get(url) + assert rv.data == b"ok" + + +def test_request_and_request_uri_parameters(test_client, server): + """Passing both requests and request_uri parameters should return an error.""" + + payload = {"response_type": "code", "client_id": "client-id"} + request_obj = jwt.encode( + {"alg": "RS256"}, payload, jwk.import_key(read_file_path("jwk_private.json")) + ) + register_request_object_extension(server, request_object=request_obj) + + url = add_params_to_uri( + authorize_url, + { + "client_id": "client-id", + "request": request_obj, + "request_uri": "https://client.test/request_object", + }, + ) + rv = test_client.get(url) + params = json.loads(rv.data) + assert params["error"] == "invalid_request" + assert ( + params["error_description"] + == "The 'request' and 'request_uri' parameters are mutually exclusive." + ) + + +def test_neither_request_nor_request_uri_parameter(test_client, server): + """Passing parameters in the query string and not in a request object should still work.""" + + register_request_object_extension(server) + url = add_params_to_uri( + authorize_url, {"response_type": "code", "client_id": "client-id"} + ) + rv = test_client.get(url) + assert rv.data == b"ok" + + +def test_server_require_request_object(test_client, server, metadata): + """When server metadata 'require_signed_request_object' is true, request objects must be used.""" + metadata["require_signed_request_object"] = True + register_request_object_extension(server, metadata=metadata) + url = add_params_to_uri( + authorize_url, {"response_type": "code", "client_id": "client-id"} + ) + rv = test_client.get(url) + params = json.loads(rv.data) + assert params["error"] == "invalid_request" + assert ( + params["error_description"] + == "Authorization requests for this server must use signed request objects." + ) + + +def test_server_require_request_object_alg_none(test_client, server, metadata): + """When server metadata 'require_signed_request_object' is true, the JWT alg cannot be none.""" + + metadata["require_signed_request_object"] = True + register_request_object_extension(server, metadata=metadata) + payload = {"response_type": "code", "client_id": "client-id"} + request_obj = jwt.encode( + {"alg": "none"}, + payload, + jwk.import_key(read_file_path("jwk_private.json")), + algorithms=["none"], + ) + url = add_params_to_uri( + authorize_url, {"client_id": "client-id", "request": request_obj} + ) + rv = test_client.get(url) + params = json.loads(rv.data) + assert params["error"] == "invalid_request" + assert ( + params["error_description"] + == "Authorization requests must be signed with supported algorithms." + ) + + +def test_client_require_signed_request_object(test_client, client, server, db): + """When client metadata 'require_signed_request_object' is true, request objects must be used.""" + + register_request_object_extension(server) + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "profile address", + "token_endpoint_auth_method": "client_secret_basic", + "response_types": ["code"], + "grant_types": ["authorization_code"], + "jwks": read_file_path("jwks_public.json"), + "require_signed_request_object": True, + } + ) + db.session.add(client) + db.session.commit() + + url = add_params_to_uri( + authorize_url, {"response_type": "code", "client_id": "client-id"} + ) + rv = test_client.get(url) + params = json.loads(rv.data) + assert params["error"] == "invalid_request" + assert ( + params["error_description"] + == "Authorization requests for this client must use signed request objects." + ) + + +def test_client_require_signed_request_object_alg_none(test_client, client, server, db): + """When client metadata 'require_signed_request_object' is true, the JWT alg cannot be none.""" + + register_request_object_extension(server) + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "profile address", + "token_endpoint_auth_method": "client_secret_basic", + "response_types": ["code"], + "grant_types": ["authorization_code"], + "jwks": read_file_path("jwks_public.json"), + "require_signed_request_object": True, + } + ) + db.session.add(client) + db.session.commit() + + payload = {"response_type": "code", "client_id": "client-id"} + request_obj = jwt.encode( + {"alg": "none"}, payload, jwk.generate_key("oct"), algorithms=["none"] + ) + url = add_params_to_uri( + authorize_url, {"client_id": "client-id", "request": request_obj} + ) + rv = test_client.get(url) + params = json.loads(rv.data) + assert params["error"] == "invalid_request" + assert ( + params["error_description"] + == "Authorization requests must be signed with supported algorithms." + ) + + +def test_unsupported_request_parameter(test_client, server): + """Passing the request parameter when unsupported should raise a 'request_not_supported' error.""" + + register_request_object_extension(server, support_request=False) + payload = {"response_type": "code", "client_id": "client-id"} + request_obj = jwt.encode( + {"alg": "RS256"}, payload, jwk.import_key(read_file_path("jwk_private.json")) + ) + url = add_params_to_uri( + authorize_url, {"client_id": "client-id", "request": request_obj} + ) + rv = test_client.get(url) + params = json.loads(rv.data) + assert params["error"] == "request_not_supported" + assert ( + params["error_description"] + == "The authorization server does not support the use of the request parameter." + ) + + +def test_unsupported_request_uri_parameter(test_client, server): + """Passing the request parameter when unsupported should raise a 'request_uri_not_supported' error.""" + + payload = {"response_type": "code", "client_id": "client-id"} + request_obj = jwt.encode( + {"alg": "RS256"}, payload, jwk.import_key(read_file_path("jwk_private.json")) + ) + register_request_object_extension( + server, request_object=request_obj, support_request_uri=False + ) + + url = add_params_to_uri( + authorize_url, + { + "client_id": "client-id", + "request_uri": "https://client.test/request_object", + }, + ) + rv = test_client.get(url) + params = json.loads(rv.data) + assert params["error"] == "request_uri_not_supported" + assert ( + params["error_description"] + == "The authorization server does not support the use of the request_uri parameter." + ) + + +def test_invalid_request_uri_parameter(test_client, server): + """Invalid request_uri (or unreachable etc.) should raise a invalid_request_uri error.""" + + register_request_object_extension(server) + url = add_params_to_uri( + authorize_url, + { + "client_id": "client-id", + "request_uri": "https://client.test/request_object", + }, + ) + rv = test_client.get(url) + params = json.loads(rv.data) + assert params["error"] == "invalid_request_uri" + assert ( + params["error_description"] + == "The request_uri in the authorization request returns an error or contains invalid data." + ) + + +def test_invalid_request_object(test_client, server): + """Invalid request object should raise a invalid_request_object error.""" + + register_request_object_extension(server) + url = add_params_to_uri( + authorize_url, + { + "client_id": "client-id", + "request": "invalid", + }, + ) + rv = test_client.get(url) + params = json.loads(rv.data) + assert params["error"] == "invalid_request_object" + assert ( + params["error_description"] + == "The request parameter contains an invalid Request Object." + ) + + +def test_missing_client_id(test_client, server): + """The client_id parameter is mandatory.""" + + register_request_object_extension(server) + payload = {"response_type": "code", "client_id": "client-id"} + request_obj = jwt.encode( + {"alg": "RS256"}, payload, jwk.import_key(read_file_path("jwk_private.json")) + ) + url = add_params_to_uri(authorize_url, {"request": request_obj}) + + rv = test_client.get(url) + params = json.loads(rv.data) + assert params["error"] == "invalid_client" + assert params["error_description"] == "Missing 'client_id' parameter." + + +def test_invalid_client_id(test_client, server): + """The client_id parameter is mandatory.""" + + register_request_object_extension(server) + payload = {"response_type": "code", "client_id": "invalid"} + request_obj = jwt.encode( + {"alg": "RS256"}, payload, jwk.import_key(read_file_path("jwk_private.json")) + ) + url = add_params_to_uri( + authorize_url, {"client_id": "invalid", "request": request_obj} + ) + + rv = test_client.get(url) + params = json.loads(rv.data) + assert params["error"] == "invalid_client" + assert params["error_description"] == "The client does not exist on this server." + + +def test_different_client_id(test_client, server): + """The client_id parameter should be the same in the request payload and the request object.""" + + register_request_object_extension(server) + payload = {"response_type": "code", "client_id": "other-code-client"} + request_obj = jwt.encode( + {"alg": "RS256"}, payload, jwk.import_key(read_file_path("jwk_private.json")) + ) + url = add_params_to_uri( + authorize_url, {"client_id": "client-id", "request": request_obj} + ) + rv = test_client.get(url) + params = json.loads(rv.data) + assert params["error"] == "invalid_request" + assert ( + params["error_description"] + == "The 'client_id' claim from the request parameters and the request object claims don't match." + ) + + +def test_request_param_in_request_object(test_client, server): + """The request and request_uri parameters should not be present in the request object.""" + + register_request_object_extension(server) + payload = { + "response_type": "code", + "client_id": "client-id", + "request_uri": "https://client.test/request_object", + } + request_obj = jwt.encode( + {"alg": "RS256"}, payload, jwk.import_key(read_file_path("jwk_private.json")) + ) + url = add_params_to_uri( + authorize_url, {"client_id": "client-id", "request": request_obj} + ) + rv = test_client.get(url) + params = json.loads(rv.data) + assert params["error"] == "invalid_request" + assert ( + params["error_description"] + == "The 'request' and 'request_uri' parameters must not be included in the request object." + ) + + +def test_registration(test_client, server): + """The 'require_signed_request_object' parameter should be available for client registration.""" + register_request_object_extension(server) + headers = {"Authorization": "bearer abc"} + + # Default case + body = { + "client_name": "Authlib", + } + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["client_name"] == "Authlib" + assert resp["require_signed_request_object"] is False + + # Nominal case + body = { + "require_signed_request_object": True, + "client_name": "Authlib", + } + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["client_name"] == "Authlib" + assert resp["require_signed_request_object"] is True + + # Error case + body = { + "require_signed_request_object": "invalid", + "client_name": "Authlib", + } + rv = test_client.post("/create_client", json=body, headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client_metadata" diff --git a/tests/flask/test_oauth2/test_jwt_bearer_client_auth.py b/tests/flask/test_oauth2/test_jwt_bearer_client_auth.py index 65e449913..3123a3f52 100644 --- a/tests/flask/test_oauth2/test_jwt_bearer_client_auth.py +++ b/tests/flask/test_oauth2/test_jwt_bearer_client_auth.py @@ -1,153 +1,344 @@ +import time + +import pytest from flask import json +from joserfc import jws +from joserfc import jwt +from joserfc.jwk import OctKey + from authlib.oauth2.rfc6749.grants import ClientCredentialsGrant -from authlib.oauth2.rfc7523 import ( - JWTBearerClientAssertion, - client_secret_jwt_sign, - private_key_jwt_sign, -) +from authlib.oauth2.rfc7523 import JWTBearerClientAssertion +from authlib.oauth2.rfc7523 import client_secret_jwt_sign +from authlib.oauth2.rfc7523 import private_key_jwt_sign from tests.util import read_file_path -from .models import db, User, Client -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server - - -class JWTClientCredentialsGrant(ClientCredentialsGrant): - TOKEN_ENDPOINT_AUTH_METHODS = [ - JWTBearerClientAssertion.CLIENT_AUTH_METHOD, - ] - - -class JWTClientAuth(JWTBearerClientAssertion): - def validate_jti(self, claims, jti): - return True - - def resolve_client_public_key(self, client, headers): - if headers['alg'] == 'RS256': - return read_file_path('jwk_public.json') - return client.client_secret - - -class ClientCredentialsTest(TestCase): - def prepare_data(self, auth_method, validate_jti=True): - server = create_authorization_server(self.app) - server.register_grant(JWTClientCredentialsGrant) - server.register_client_auth_method( - JWTClientAuth.CLIENT_AUTH_METHOD, - JWTClientAuth('https://localhost/oauth/token', validate_jti) - ) - - user = User(username='foo') - db.session.add(user) - db.session.commit() - client = Client( - user_id=user.id, - client_id='credential-client', - client_secret='credential-secret', - ) - client.set_client_metadata({ - 'scope': 'profile', - 'redirect_uris': ['http://localhost/authorized'], - 'grant_types': ['client_credentials'], - 'token_endpoint_auth_method': auth_method, - }) - db.session.add(client) - db.session.commit() - - def test_invalid_client(self): - self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - def test_invalid_jwt(self): - self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD) - - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, - 'client_assertion': client_secret_jwt_sign( - client_secret='invalid-secret', - client_id='credential-client', - token_endpoint='https://localhost/oauth/token', - ) - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - def test_not_found_client(self): - self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD) - - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, - 'client_assertion': client_secret_jwt_sign( - client_secret='credential-secret', - client_id='invalid-client', - token_endpoint='https://localhost/oauth/token', - ) - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - def test_not_supported_auth_method(self): - self.prepare_data('invalid') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, - 'client_assertion': client_secret_jwt_sign( - client_secret='credential-secret', - client_id='credential-client', - token_endpoint='https://localhost/oauth/token', - ) - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - def test_client_secret_jwt(self): - self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD) - - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, - 'client_assertion': client_secret_jwt_sign( - client_secret='credential-secret', - client_id='credential-client', - token_endpoint='https://localhost/oauth/token', - claims={'jti': 'nonce'}, - ) - }) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - - def test_private_key_jwt(self): - self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD) - - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, - 'client_assertion': private_key_jwt_sign( - private_key=read_file_path('jwk_private.json'), - client_id='credential-client', - token_endpoint='https://localhost/oauth/token', - ) - }) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - - def test_not_validate_jti(self): - self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD, False) - - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, - 'client_assertion': client_secret_jwt_sign( - client_secret='credential-secret', - client_id='credential-client', - token_endpoint='https://localhost/oauth/token', - ) - }) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) + + +@pytest.fixture(autouse=True) +def server(server): + class JWTClientCredentialsGrant(ClientCredentialsGrant): + TOKEN_ENDPOINT_AUTH_METHODS = [ + JWTBearerClientAssertion.CLIENT_AUTH_METHOD, + ] + + server.register_grant(JWTClientCredentialsGrant) + return server + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["https://client.test/authorized"], + "grant_types": ["client_credentials"], + "token_endpoint_auth_method": JWTBearerClientAssertion.CLIENT_AUTH_METHOD, + } + ) + db.session.add(client) + db.session.commit() + return client + + +def register_jwt_client_auth(server, validate_jti=True): + class JWTClientAuth(JWTBearerClientAssertion): + def get_audiences(self): + return ["https://provider.test/oauth/token"] + + def validate_jti(self, claims, jti): + return jti != "used" + + def resolve_client_public_key(self, client, headers): + if headers["alg"] == "RS256": + return read_file_path("jwk_public.json") + return client.client_secret + + server.register_client_auth_method( + JWTClientAuth.CLIENT_AUTH_METHOD, + JWTClientAuth(validate_jti=validate_jti), + ) + + +def test_invalid_client(test_client, server): + register_jwt_client_auth(server) + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + "client_assertion_type": JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + +def test_invalid_jwt(test_client, server): + register_jwt_client_auth(server) + + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + "client_assertion_type": JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + "client_assertion": client_secret_jwt_sign( + client_secret="invalid-secret", + client_id="client-id", + token_endpoint="https://provider.test/oauth/token", + ), + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + +def test_not_found_client(test_client, server): + register_jwt_client_auth(server) + + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + "client_assertion_type": JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + "client_assertion": client_secret_jwt_sign( + client_secret="client-secret", + client_id="invalid-client", + token_endpoint="https://provider.test/oauth/token", + ), + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + +def test_not_supported_auth_method(test_client, server, client, db): + register_jwt_client_auth(server) + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["https://client.test/authorized"], + "grant_types": ["client_credentials"], + "token_endpoint_auth_method": "invalid", + } + ) + db.session.add(client) + db.session.commit() + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + "client_assertion_type": JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + "client_assertion": client_secret_jwt_sign( + client_secret="client-secret", + client_id="client-id", + token_endpoint="https://provider.test/oauth/token", + ), + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + +def test_client_secret_jwt(test_client, server): + register_jwt_client_auth(server) + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + "client_assertion_type": JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + "client_assertion": client_secret_jwt_sign( + client_secret="client-secret", + client_id="client-id", + token_endpoint="https://provider.test/oauth/token", + claims={"jti": "nonce"}, + ), + }, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + + +def test_private_key_jwt(test_client, server): + register_jwt_client_auth(server) + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + "client_assertion_type": JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + "client_assertion": private_key_jwt_sign( + private_key=read_file_path("jwk_private.json"), + client_id="client-id", + token_endpoint="https://provider.test/oauth/token", + ), + }, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + + +def test_not_validate_jti(test_client, server): + register_jwt_client_auth(server, validate_jti=False) + + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + "client_assertion_type": JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + "client_assertion": client_secret_jwt_sign( + client_secret="client-secret", + client_id="client-id", + token_endpoint="https://provider.test/oauth/token", + ), + }, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + + +def test_validate_jti_failed(test_client, server): + register_jwt_client_auth(server) + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + "client_assertion_type": JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + "client_assertion": client_secret_jwt_sign( + client_secret="client-secret", + client_id="client-id", + token_endpoint="https://provider.test/oauth/token", + claims={"jti": "used"}, + ), + }, + ) + resp = json.loads(rv.data) + assert "JWT ID" in resp["error_description"] + + +def test_invalid_assertion(test_client, server): + register_jwt_client_auth(server) + client_assertion = jws.serialize_compact( + {"alg": "HS256"}, + "text", + OctKey.import_key("client-secret"), + ) + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + "client_assertion_type": JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + "client_assertion": client_assertion, + }, + ) + resp = json.loads(rv.data) + assert "Invalid JWT" in resp["error_description"] + + +def test_missing_exp_claim(test_client, server): + register_jwt_client_auth(server) + key = OctKey.import_key("client-secret") + # missing "exp" value + claims = { + "iss": "client-id", + "sub": "client-id", + "aud": "https://provider.test/oauth/token", + "jti": "nonce", + } + client_assertion = jwt.encode({"alg": "HS256"}, claims, key) + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + "client_assertion_type": JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + "client_assertion": client_assertion, + }, + ) + resp = json.loads(rv.data) + assert "error" in resp + assert "'exp'" in resp["error_description"] + + +def test_iss_sub_not_same(test_client, server): + register_jwt_client_auth(server) + key = OctKey.import_key("client-secret") + # missing "exp" value + claims = { + "sub": "client-id", + "iss": "invalid-iss", + "aud": "https://provider.test/oauth/token", + "exp": int(time.time() + 3600), + "jti": "nonce", + } + client_assertion = jwt.encode({"alg": "HS256"}, claims, key) + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + "client_assertion_type": JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + "client_assertion": client_assertion, + }, + ) + resp = json.loads(rv.data) + assert "error" in resp + assert resp["error_description"] == "Issuer and Subject MUST match." + + +def test_missing_jti(test_client, server): + register_jwt_client_auth(server) + key = OctKey.import_key("client-secret") + # missing "exp" value + claims = { + "sub": "client-id", + "iss": "client-id", + "aud": "https://provider.test/oauth/token", + "exp": int(time.time() + 3600), + } + client_assertion = jwt.encode({"alg": "HS256"}, claims, key) + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + "client_assertion_type": JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + "client_assertion": client_assertion, + }, + ) + resp = json.loads(rv.data) + assert "error" in resp + assert resp["error_description"] == "Missing JWT ID." + + +def test_issuer_as_audience(test_client, server): + """Per RFC 7523 Section 3 and draft-ietf-oauth-rfc7523bis, the AS issuer + identifier should be a valid audience value for client assertion JWTs.""" + + class JWTClientAuth(JWTBearerClientAssertion): + def get_audiences(self): + return ["https://provider.test/oauth/token", "https://provider.test"] + + def validate_jti(self, claims, jti): + return True + + def resolve_client_public_key(self, client, headers): + return client.client_secret + + server.register_client_auth_method( + JWTClientAuth.CLIENT_AUTH_METHOD, + JWTClientAuth(), + ) + + key = OctKey.import_key("client-secret") + claims = { + "iss": "client-id", + "sub": "client-id", + "aud": "https://provider.test", + "exp": int(time.time() + 3600), + "jti": "nonce", + } + client_assertion = jwt.encode({"alg": "HS256"}, claims, key) + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + "client_assertion_type": JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + "client_assertion": client_assertion, + }, + ) + resp = json.loads(rv.data) + assert "access_token" in resp diff --git a/tests/flask/test_oauth2/test_jwt_bearer_grant.py b/tests/flask/test_oauth2/test_jwt_bearer_grant.py index 41ca77e96..95ed243b4 100644 --- a/tests/flask/test_oauth2/test_jwt_bearer_grant.py +++ b/tests/flask/test_oauth2/test_jwt_bearer_grant.py @@ -1,106 +1,214 @@ +import pytest from flask import json +from joserfc import jws +from joserfc import jwt +from joserfc.jwk import KeySet +from joserfc.jwk import OctKey + from authlib.oauth2.rfc7523 import JWTBearerGrant as _JWTBearerGrant -from .models import db, User, Client -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server +from authlib.oauth2.rfc7523 import JWTBearerTokenGenerator +from tests.util import read_file_path + +from .models import Client +from .models import db class JWTBearerGrant(_JWTBearerGrant): - def authenticate_user(self, client, claims): - return None + def resolve_issuer_client(self, issuer): + return Client.query.filter_by(client_id=issuer).first() - def authenticate_client(self, claims): - iss = claims['iss'] - return Client.query.filter_by(client_id=iss).first() - - def resolve_public_key(self, headers, payload): - keys = {'1': 'foo', '2': 'bar'} - return keys[headers['kid']] - - -class JWTBearerGrantTest(TestCase): - def prepare_data(self, grant_type=None): - server = create_authorization_server(self.app) - server.register_grant(JWTBearerGrant) - - user = User(username='foo') - db.session.add(user) - db.session.commit() - if grant_type is None: - grant_type = JWTBearerGrant.GRANT_TYPE - client = Client( - user_id=user.id, - client_id='jwt-client', - client_secret='jwt-secret', - ) - client.set_client_metadata({ - 'scope': 'profile', - 'redirect_uris': ['http://localhost/authorized'], - 'grant_types': [grant_type], - }) - db.session.add(client) - db.session.commit() - - def test_missing_assertion(self): - self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': JWTBearerGrant.GRANT_TYPE - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - self.assertIn('assertion', resp['error_description']) - - def test_invalid_assertion(self): - self.prepare_data() - assertion = JWTBearerGrant.sign( - 'foo', issuer='jwt-client', audience='https://i.b/token', - header={'alg': 'HS256', 'kid': '1'} - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': JWTBearerGrant.GRANT_TYPE, - 'assertion': assertion - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_grant') - - def test_authorize_token(self): - self.prepare_data() - assertion = JWTBearerGrant.sign( - 'foo', issuer='jwt-client', audience='https://i.b/token', - subject='self', header={'alg': 'HS256', 'kid': '1'} + def resolve_client_public_key(self, client): + return KeySet( + [ + OctKey.import_key("foo", {"kid": "1"}), + OctKey.import_key("bar", {"kid": "2"}), + ] ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': JWTBearerGrant.GRANT_TYPE, - 'assertion': assertion - }) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - - def test_unauthorized_client(self): - self.prepare_data('password') - assertion = JWTBearerGrant.sign( - 'bar', issuer='jwt-client', audience='https://i.b/token', - subject='self', header={'alg': 'HS256', 'kid': '2'} - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': JWTBearerGrant.GRANT_TYPE, - 'assertion': assertion - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'unauthorized_client') - - def test_token_generator(self): - m = 'tests.flask.test_oauth2.oauth2_server:token_generator' - self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) - self.prepare_data() - assertion = JWTBearerGrant.sign( - 'foo', issuer='jwt-client', audience='https://i.b/token', - subject='self', header={'alg': 'HS256', 'kid': '1'} - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': JWTBearerGrant.GRANT_TYPE, - 'assertion': assertion - }) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - self.assertIn('j-', resp['access_token']) + + def authenticate_user(self, subject): + return None + + def has_granted_permission(self, client, user): + return True + + +@pytest.fixture(autouse=True) +def server(server): + server.register_grant(JWTBearerGrant) + return server + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["https://client.test/authorized"], + "grant_types": [JWTBearerGrant.GRANT_TYPE], + } + ) + db.session.add(client) + db.session.commit() + return client + + +def test_missing_assertion(test_client): + rv = test_client.post( + "/oauth/token", data={"grant_type": JWTBearerGrant.GRANT_TYPE} + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_request" + assert "assertion" in resp["error_description"] + + +def test_invalid_assertion(test_client): + assertion = JWTBearerGrant.sign( + "foo", + issuer="client-id", + audience="https://provider.test/token", + subject="none", + header={"alg": "HS256", "kid": "1"}, + ) + rv = test_client.post( + "/oauth/token", + data={"grant_type": JWTBearerGrant.GRANT_TYPE, "assertion": assertion}, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_grant" + + +def test_authorize_token(test_client): + assertion = JWTBearerGrant.sign( + "foo", + issuer="client-id", + audience="https://provider.test/token", + subject=None, + header={"alg": "HS256", "kid": "1"}, + ) + rv = test_client.post( + "/oauth/token", + data={"grant_type": JWTBearerGrant.GRANT_TYPE, "assertion": assertion}, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + + +def test_unauthorized_client(test_client, client): + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["https://client.test/authorized"], + "grant_types": ["password"], + } + ) + db.session.add(client) + db.session.commit() + + assertion = JWTBearerGrant.sign( + "bar", + issuer="client-id", + audience="https://provider.test/token", + subject=None, + header={"alg": "HS256", "kid": "2"}, + ) + rv = test_client.post( + "/oauth/token", + data={"grant_type": JWTBearerGrant.GRANT_TYPE, "assertion": assertion}, + ) + resp = json.loads(rv.data) + assert resp["error"] == "unauthorized_client" + + +def test_token_generator(test_client, app, server): + m = "tests.flask.test_oauth2.oauth2_server:token_generator" + app.config.update({"OAUTH2_ACCESS_TOKEN_GENERATOR": m}) + server.load_config(app.config) + assertion = JWTBearerGrant.sign( + "foo", + issuer="client-id", + audience="https://provider.test/token", + subject=None, + header={"alg": "HS256", "kid": "1"}, + ) + rv = test_client.post( + "/oauth/token", + data={"grant_type": JWTBearerGrant.GRANT_TYPE, "assertion": assertion}, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert "c-" in resp["access_token"] + + +def test_jwt_bearer_token_generator(test_client, server): + private_key = read_file_path("jwks_private.json") + server.register_token_generator( + JWTBearerGrant.GRANT_TYPE, + JWTBearerTokenGenerator(private_key), + ) + assertion = JWTBearerGrant.sign( + "foo", + issuer="client-id", + audience="https://provider.test/token", + subject=None, + header={"alg": "HS256", "kid": "1"}, + ) + rv = test_client.post( + "/oauth/token", + data={"grant_type": JWTBearerGrant.GRANT_TYPE, "assertion": assertion}, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert resp["access_token"].count(".") == 2 + + +def test_invalid_payload_assertion(test_client): + assertion = jws.serialize_compact( + {"alg": "HS256", "kid": "1"}, + "text", + OctKey.import_key("foo"), + ) + rv = test_client.post( + "/oauth/token", + data={"grant_type": JWTBearerGrant.GRANT_TYPE, "assertion": assertion}, + ) + resp = json.loads(rv.data) + assert "Invalid JWT" in resp["error_description"] + + +def test_missing_assertion_claims(test_client): + assertion = jwt.encode( + {"alg": "HS256", "kid": "1"}, + {"iss": "client-id"}, + OctKey.import_key("foo"), + ) + rv = test_client.post( + "/oauth/token", + data={"grant_type": JWTBearerGrant.GRANT_TYPE, "assertion": assertion}, + ) + resp = json.loads(rv.data) + assert "Missing claim" in resp["error_description"] + + +def test_invalid_audience(test_client, server): + """RFC 7523 Section 3: The authorization server MUST reject any JWT that + does not contain its own identity as the intended audience.""" + + class StrictAudienceGrant(JWTBearerGrant): + def get_audiences(self): + return ["https://provider.test/token"] + + server._token_grants.clear() + server.register_grant(StrictAudienceGrant) + assertion = StrictAudienceGrant.sign( + "foo", + issuer="client-id", + audience="https://evil.test/token", + subject=None, + header={"alg": "HS256", "kid": "1"}, + ) + rv = test_client.post( + "/oauth/token", + data={"grant_type": StrictAudienceGrant.GRANT_TYPE, "assertion": assertion}, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_grant" diff --git a/tests/flask/test_oauth2/test_oauth2_server.py b/tests/flask/test_oauth2/test_oauth2_server.py index 0e16d9c10..e5b667195 100644 --- a/tests/flask/test_oauth2/test_oauth2_server.py +++ b/tests/flask/test_oauth2/test_oauth2_server.py @@ -1,206 +1,177 @@ -from flask import json, jsonify -from authlib.integrations.flask_oauth2 import ResourceProtector, current_token +import pytest +from flask import json +from flask import jsonify + +from authlib.integrations.flask_oauth2 import ResourceProtector +from authlib.integrations.flask_oauth2 import current_token from authlib.integrations.sqla_oauth2 import create_bearer_token_validator -from .models import db, User, Client, Token -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server -require_oauth = ResourceProtector() -BearerTokenValidator = create_bearer_token_validator(db.session, Token) -require_oauth.register_token_validator(BearerTokenValidator()) +from .models import Token +from .oauth2_server import create_bearer_header + + +@pytest.fixture(autouse=True) +def server(server): + return server -def create_resource_server(app): - @app.route('/user') - @require_oauth('profile') +@pytest.fixture(autouse=True) +def resource_server(app, db): + require_oauth = ResourceProtector() + BearerTokenValidator = create_bearer_token_validator(db.session, Token) + require_oauth.register_token_validator(BearerTokenValidator()) + + @app.route("/user") + @require_oauth("profile") def user_profile(): user = current_token.user return jsonify(id=user.id, username=user.username) - @app.route('/user/email') - @require_oauth('email') + @app.route("/user/email") + @require_oauth("email") def user_email(): user = current_token.user - return jsonify(email=user.username + '@example.com') + return jsonify(email=user.username + "@example.com") - @app.route('/info') + @app.route("/info") @require_oauth() def public_info(): - return jsonify(status='ok') + return jsonify(status="ok") + + @app.route("/no-parens") + @require_oauth + def no_parens(): + return jsonify(status="ok") - @app.route('/operator-and') - @require_oauth('profile email', 'AND') + @app.route("/operator-and") + @require_oauth(["profile email"]) def operator_and(): - return jsonify(status='ok') + return jsonify(status="ok") - @app.route('/operator-or') - @require_oauth('profile email', 'OR') + @app.route("/operator-or") + @require_oauth(["profile", "email"]) def operator_or(): - return jsonify(status='ok') + return jsonify(status="ok") - def scope_operator(token_scopes, resource_scopes): - return 'profile' in token_scopes and 'email' not in token_scopes - - @app.route('/operator-func') - @require_oauth(operator=scope_operator) - def operator_func(): - return jsonify(status='ok') - - @app.route('/acquire') + @app.route("/acquire") def test_acquire(): - with require_oauth.acquire('profile') as token: + with require_oauth.acquire("profile") as token: user = token.user return jsonify(id=user.id, username=user.username) - @app.route('/optional') - @require_oauth('profile', optional=True) + @app.route("/optional") + @require_oauth("profile", optional=True) def test_optional_token(): if current_token: user = current_token.user return jsonify(id=user.id, username=user.username) else: - return jsonify(id=0, username='anonymous') - - -class AuthorizationTest(TestCase): - def test_none_grant(self): - create_authorization_server(self.app) - authorize_url = ( - '/oauth/authorize?response_type=token' - '&client_id=implicit-client' - ) - rv = self.client.get(authorize_url) - self.assertIn(b'invalid_grant', rv.data) - - rv = self.client.post(authorize_url, data={'user_id': '1'}) - self.assertNotEqual(rv.status, 200) - - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': 'x', - }) - data = json.loads(rv.data) - self.assertEqual(data['error'], 'unsupported_grant_type') - - -class ResourceTest(TestCase): - def prepare_data(self): - create_resource_server(self.app) - - user = User(username='foo') - db.session.add(user) - db.session.commit() - client = Client( - user_id=user.id, - client_id='resource-client', - client_secret='resource-secret', - ) - client.set_client_metadata({ - 'scope': 'profile', - 'redirect_uris': ['http://localhost/authorized'], - }) - db.session.add(client) - db.session.commit() - - def create_token(self, expires_in=3600): - token = Token( - user_id=1, - client_id='resource-client', - token_type='bearer', - access_token='a1', - scope='profile', - expires_in=expires_in, - ) - db.session.add(token) - db.session.commit() - - def create_bearer_header(self, token): - return {'Authorization': 'Bearer ' + token} - - def test_invalid_token(self): - self.prepare_data() - - rv = self.client.get('/user') - self.assertEqual(rv.status_code, 401) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'missing_authorization') - - headers = {'Authorization': 'invalid token'} - rv = self.client.get('/user', headers=headers) - self.assertEqual(rv.status_code, 401) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'unsupported_token_type') - - headers = self.create_bearer_header('invalid') - rv = self.client.get('/user', headers=headers) - self.assertEqual(rv.status_code, 401) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_token') - - def test_expired_token(self): - self.prepare_data() - self.create_token(0) - headers = self.create_bearer_header('a1') - - rv = self.client.get('/user', headers=headers) - self.assertEqual(rv.status_code, 401) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_token') - - rv = self.client.get('/acquire', headers=headers) - self.assertEqual(rv.status_code, 401) - - def test_insufficient_token(self): - self.prepare_data() - self.create_token() - headers = self.create_bearer_header('a1') - rv = self.client.get('/user/email', headers=headers) - self.assertEqual(rv.status_code, 403) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'insufficient_scope') - - def test_access_resource(self): - self.prepare_data() - self.create_token() - headers = self.create_bearer_header('a1') - - rv = self.client.get('/user', headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['username'], 'foo') - - rv = self.client.get('/acquire', headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['username'], 'foo') - - rv = self.client.get('/info', headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['status'], 'ok') - - def test_scope_operator(self): - self.prepare_data() - self.create_token() - headers = self.create_bearer_header('a1') - rv = self.client.get('/operator-and', headers=headers) - self.assertEqual(rv.status_code, 403) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'insufficient_scope') - - rv = self.client.get('/operator-or', headers=headers) - self.assertEqual(rv.status_code, 200) - - rv = self.client.get('/operator-func', headers=headers) - self.assertEqual(rv.status_code, 200) - - def test_optional_token(self): - self.prepare_data() - rv = self.client.get('/optional') - self.assertEqual(rv.status_code, 200) - resp = json.loads(rv.data) - self.assertEqual(resp['username'], 'anonymous') - - self.create_token() - headers = self.create_bearer_header('a1') - rv = self.client.get('/optional', headers=headers) - self.assertEqual(rv.status_code, 200) - resp = json.loads(rv.data) - self.assertEqual(resp['username'], 'foo') + return jsonify(id=0, username="anonymous") + + return require_oauth + + +def test_authorization_none_grant(test_client): + authorize_url = "/oauth/authorize?response_type=token&client_id=implicit-client" + rv = test_client.get(authorize_url) + assert b"unsupported_response_type" in rv.data + + rv = test_client.post(authorize_url, data={"user_id": "1"}) + assert rv.status != 200 + + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": "x", + }, + ) + data = json.loads(rv.data) + assert data["error"] == "unsupported_grant_type" + + +def test_invalid_token(test_client, token): + rv = test_client.get("/user") + assert rv.status_code == 401 + resp = json.loads(rv.data) + assert resp["error"] == "missing_authorization" + + headers = {"Authorization": "invalid token"} + rv = test_client.get("/user", headers=headers) + assert rv.status_code == 401 + resp = json.loads(rv.data) + assert resp["error"] == "unsupported_token_type" + + headers = create_bearer_header("invalid") + rv = test_client.get("/user", headers=headers) + assert rv.status_code == 401 + resp = json.loads(rv.data) + assert resp["error"] == "invalid_token" + + +def test_expired_token(test_client, db, token): + token.expires_in = -10 + db.session.add(token) + db.session.commit() + + headers = create_bearer_header("a1") + + rv = test_client.get("/user", headers=headers) + assert rv.status_code == 401 + resp = json.loads(rv.data) + assert resp["error"] == "invalid_token" + + rv = test_client.get("/acquire", headers=headers) + assert rv.status_code == 401 + + +def test_insufficient_token(test_client, token): + headers = create_bearer_header("a1") + rv = test_client.get("/user/email", headers=headers) + assert rv.status_code == 403 + resp = json.loads(rv.data) + assert resp["error"] == "insufficient_scope" + + +def test_access_resource(test_client, token): + headers = create_bearer_header("a1") + + rv = test_client.get("/user", headers=headers) + resp = json.loads(rv.data) + assert resp["username"] == "foo" + + rv = test_client.get("/acquire", headers=headers) + resp = json.loads(rv.data) + assert resp["username"] == "foo" + + rv = test_client.get("/info", headers=headers) + resp = json.loads(rv.data) + assert resp["status"] == "ok" + + rv = test_client.get("/no-parens", headers=headers) + resp = json.loads(rv.data) + assert resp["status"] == "ok" + + +def test_scope_operator(test_client, token): + headers = create_bearer_header("a1") + rv = test_client.get("/operator-and", headers=headers) + assert rv.status_code == 403 + resp = json.loads(rv.data) + assert resp["error"] == "insufficient_scope" + + rv = test_client.get("/operator-or", headers=headers) + assert rv.status_code == 200 + + +def test_optional_token(test_client, token): + rv = test_client.get("/optional") + assert rv.status_code == 200 + resp = json.loads(rv.data) + assert resp["username"] == "anonymous" + + headers = create_bearer_header("a1") + rv = test_client.get("/optional", headers=headers) + assert rv.status_code == 200 + resp = json.loads(rv.data) + assert resp["username"] == "foo" diff --git a/tests/flask/test_oauth2/test_openid_code_grant.py b/tests/flask/test_oauth2/test_openid_code_grant.py index 9b7601bd3..561be27d9 100644 --- a/tests/flask/test_oauth2/test_openid_code_grant.py +++ b/tests/flask/test_oauth2/test_openid_code_grant.py @@ -1,260 +1,499 @@ +import time + +import pytest +from flask import current_app from flask import json -from authlib.common.urls import urlparse, url_decode, url_encode -from authlib.jose import JsonWebToken, JsonWebKey -from authlib.oidc.core import CodeIDToken -from authlib.oidc.core.grants import OpenIDCode as _OpenIDCode +from joserfc import jwt +from joserfc.jwk import ECKey +from joserfc.jwk import KeySet +from joserfc.jwk import OctKey +from joserfc.jwk import RSAKey + +from authlib.common.urls import url_decode +from authlib.common.urls import url_encode +from authlib.common.urls import urlparse from authlib.oauth2.rfc6749.grants import ( AuthorizationCodeGrant as _AuthorizationCodeGrant, ) -from tests.util import get_file_path -from .models import db, User, Client, exists_nonce -from .models import CodeGrantMixin, save_authorization_code -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server - - -class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): - def save_authorization_code(self, code, request): - return save_authorization_code(code, request) - - -class OpenIDCode(_OpenIDCode): - def get_jwt_config(self, grant): - config = grant.server.config - key = config['jwt_key'] - alg = config['jwt_alg'] - iss = config['jwt_iss'] - exp = config['jwt_exp'] - return dict(key=key, alg=alg, iss=iss, exp=exp) - - def exists_nonce(self, nonce, request): - return exists_nonce(nonce, request) - - def generate_user_info(self, user, scopes): - return user.generate_user_info(scopes) - - -class BaseTestCase(TestCase): - def config_app(self): - self.app.config.update({ - 'OAUTH2_JWT_ENABLED': True, - 'OAUTH2_JWT_ISS': 'Authlib', - 'OAUTH2_JWT_KEY': 'secret', - 'OAUTH2_JWT_ALG': 'HS256', - }) - - def prepare_data(self): - self.config_app() - server = create_authorization_server(self.app) - server.register_grant(AuthorizationCodeGrant, [OpenIDCode()]) - - user = User(username='foo') - db.session.add(user) - db.session.commit() - - client = Client( - user_id=user.id, - client_id='code-client', - client_secret='code-secret', - ) - client.set_client_metadata({ - 'redirect_uris': ['https://a.b'], - 'scope': 'openid profile address', - 'response_types': ['code'], - 'grant_types': ['authorization_code'], - }) - db.session.add(client) - db.session.commit() - - -class OpenIDCodeTest(BaseTestCase): - def test_authorize_token(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'code', - 'client_id': 'code-client', - 'state': 'bar', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1' - }) - self.assertIn('code=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - self.assertEqual(params['state'], 'bar') - - code = params['code'] - headers = self.create_basic_header('code-client', 'code-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'redirect_uri': 'https://a.b', - 'code': code, - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - self.assertIn('id_token', resp) - - jwt = JsonWebToken() - claims = jwt.decode( - resp['id_token'], 'secret', - claims_cls=CodeIDToken, - claims_options={'iss': {'value': 'Authlib'}} - ) - claims.validate() - - def test_pure_code_flow(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'code', - 'client_id': 'code-client', - 'state': 'bar', - 'scope': 'profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1' - }) - self.assertIn('code=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - self.assertEqual(params['state'], 'bar') - - code = params['code'] - headers = self.create_basic_header('code-client', 'code-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'redirect_uri': 'https://a.b', - 'code': code, - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - self.assertNotIn('id_token', resp) - - def test_nonce_replay(self): - self.prepare_data() - data = { - 'response_type': 'code', - 'client_id': 'code-client', - 'user_id': '1', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b' +from authlib.oidc.core import CodeIDToken +from authlib.oidc.core.grants import OpenIDCode as _OpenIDCode +from tests.util import read_file_path + +from .models import Client +from .models import CodeGrantMixin +from .models import exists_nonce +from .models import save_authorization_code +from .oauth2_server import create_basic_header + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "openid profile address", + "response_types": ["code"], + "grant_types": ["authorization_code"], + } + ) + db.session.add(client) + db.session.commit() + return client + + +@pytest.fixture(autouse=True) +def server(server, app): + app.config.update( + { + "OAUTH2_JWT_ISS": "Authlib", + "OAUTH2_JWT_KEY": "secret", + "OAUTH2_JWT_ALG": "HS256", + } + ) + return server + + +def register_oidc_code_grant(server, require_nonce=False): + class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + class OpenIDCode(_OpenIDCode): + def get_jwt_config(self, grant, client): + key = current_app.config.get("OAUTH2_JWT_KEY") + alg = current_app.config.get("OAUTH2_JWT_ALG") + iss = current_app.config.get("OAUTH2_JWT_ISS") + return dict(key=key, alg=alg, iss=iss, exp=3600) + + def exists_nonce(self, nonce, request): + return exists_nonce(nonce, request) + + def generate_user_info(self, user, scopes): + return user.generate_user_info(scopes) + + server.register_grant( + AuthorizationCodeGrant, [OpenIDCode(require_nonce=require_nonce)] + ) + + +def test_authorize_token(test_client, server): + register_oidc_code_grant( + server, + ) + auth_request_time = time.time() + rv = test_client.post( + "/oauth/authorize", + data={ + "response_type": "code", + "client_id": "client-id", + "state": "bar", + "scope": "openid profile", + "redirect_uri": "https://client.test", + "user_id": "1", + }, + ) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + assert params["state"] == "bar" + + code = params["code"] + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "redirect_uri": "https://client.test", + "code": code, + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert "id_token" in resp + + token = jwt.decode(resp["id_token"], key=OctKey.import_key("secret")) + claims = CodeIDToken( + token.claims, + token.header, + {"iss": {"value": "Authlib"}}, + ) + claims.validate() + assert claims["auth_time"] >= int(auth_request_time) + assert claims["acr"] == "urn:mace:incommon:iap:silver" + assert claims["amr"] == ["pwd", "otp"] + + +def test_pure_code_flow(test_client, server): + register_oidc_code_grant( + server, + ) + rv = test_client.post( + "/oauth/authorize", + data={ + "response_type": "code", + "client_id": "client-id", + "state": "bar", + "scope": "profile", + "redirect_uri": "https://client.test", + "user_id": "1", + }, + ) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + assert params["state"] == "bar" + + code = params["code"] + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "redirect_uri": "https://client.test", + "code": code, + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert "id_token" not in resp + + +def test_require_nonce(test_client, server): + register_oidc_code_grant(server, require_nonce=True) + rv = test_client.post( + "/oauth/authorize", + data={ + "response_type": "code", + "client_id": "client-id", + "user_id": "1", + "state": "bar", + "scope": "openid profile", + "redirect_uri": "https://client.test", + }, + ) + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + assert params["error"] == "invalid_request" + assert params["error_description"] == "Missing 'nonce' in request." + + +def test_nonce_replay(test_client, server): + register_oidc_code_grant( + server, + ) + data = { + "response_type": "code", + "client_id": "client-id", + "user_id": "1", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://client.test", + } + rv = test_client.post("/oauth/authorize", data=data) + assert "code=" in rv.location + + rv = test_client.post("/oauth/authorize", data=data) + assert "error=" in rv.location + + +def test_prompt(test_client, server): + register_oidc_code_grant( + server, + ) + params = [ + ("response_type", "code"), + ("client_id", "client-id"), + ("state", "bar"), + ("nonce", "abc"), + ("scope", "openid profile"), + ("redirect_uri", "https://client.test"), + ] + query = url_encode(params) + rv = test_client.get("/oauth/authorize?" + query) + assert rv.data == b"login" + + query = url_encode(params + [("user_id", "1")]) + rv = test_client.get("/oauth/authorize?" + query) + assert rv.data == b"ok" + + query = url_encode(params + [("prompt", "login")]) + rv = test_client.get("/oauth/authorize?" + query) + assert rv.data == b"login" + + query = url_encode(params + [("user_id", "1"), ("prompt", "login")]) + rv = test_client.get("/oauth/authorize?" + query) + assert rv.data == b"login" + + +def test_prompt_none_not_logged(test_client, server): + register_oidc_code_grant( + server, + ) + params = [ + ("response_type", "code"), + ("client_id", "client-id"), + ("state", "bar"), + ("nonce", "abc"), + ("scope", "openid profile"), + ("redirect_uri", "https://client.test"), + ("prompt", "none"), + ] + query = url_encode(params) + rv = test_client.get("/oauth/authorize?" + query) + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + assert params["error"] == "login_required" + assert params["state"] == "bar" + + +def test_client_metadata_custom_alg(test_client, server, client, db, app): + """If the client metadata 'id_token_signed_response_alg' is defined, + it should be used to sign id_tokens.""" + register_oidc_code_grant( + server, + ) + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "openid profile address", + "response_types": ["code"], + "grant_types": ["authorization_code"], + "id_token_signed_response_alg": "HS384", + } + ) + db.session.add(client) + db.session.commit() + del app.config["OAUTH2_JWT_ALG"] + + rv = test_client.post( + "/oauth/authorize", + data={ + "response_type": "code", + "client_id": "client-id", + "state": "bar", + "scope": "openid profile", + "redirect_uri": "https://client.test", + "user_id": "1", + }, + ) + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + code = params["code"] + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "redirect_uri": "https://client.test", + "code": code, + }, + headers=headers, + ) + resp = json.loads(rv.data) + token = jwt.decode( + resp["id_token"], + key=OctKey.import_key("secret"), + algorithms=["HS384"], + ) + claims = CodeIDToken( + token.claims, + token.header, + {"iss": {"value": "Authlib"}}, + ) + claims.validate() + assert claims.header["alg"] == "HS384" + + +def test_client_metadata_alg_none(test_client, server, app, db, client): + """The 'none' 'id_token_signed_response_alg' alg should be + supported in non implicit flows.""" + register_oidc_code_grant( + server, + ) + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "openid profile address", + "response_types": ["code"], + "grant_types": ["authorization_code"], + "id_token_signed_response_alg": "none", + } + ) + db.session.add(client) + db.session.commit() + + del app.config["OAUTH2_JWT_ALG"] + rv = test_client.post( + "/oauth/authorize", + data={ + "response_type": "code", + "client_id": "client-id", + "state": "bar", + "scope": "openid profile", + "redirect_uri": "https://client.test", + "user_id": "1", + }, + ) + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + code = params["code"] + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "redirect_uri": "https://client.test", + "code": code, + }, + headers=headers, + ) + resp = json.loads(rv.data) + token = jwt.decode( + resp["id_token"], + key=OctKey.import_key("secret"), + algorithms=["none"], + ) + claims = CodeIDToken( + token.claims, + token.header, + {"iss": {"value": "Authlib"}}, + ) + claims.validate() + assert claims.header["alg"] == "none" + + +@pytest.mark.parametrize( + "alg, private_key, public_key", + [ + ( + "RS256", + RSAKey.import_key(read_file_path("jwk_private.json")), + RSAKey.import_key(read_file_path("jwk_public.json")), + ), + ( + "PS256", + KeySet.import_key_set(read_file_path("jwks_private.json")), + KeySet.import_key_set(read_file_path("jwks_public.json")), + ), + ( + "ES512", + ECKey.import_key(read_file_path("secp521r1-private.json")), + ECKey.import_key(read_file_path("secp521r1-public.json")), + ), + ( + "RS256", + RSAKey.import_key(read_file_path("rsa_private.pem")), + RSAKey.import_key(read_file_path("rsa_public.pem")), + ), + ], +) +def test_authorize_token_algs(test_client, server, app, alg, private_key, public_key): + # generate refresh token + app.config["OAUTH2_JWT_KEY"] = private_key + app.config["OAUTH2_JWT_ALG"] = alg + register_oidc_code_grant( + server, + ) + rv = test_client.post( + "/oauth/authorize", + data={ + "response_type": "code", + "client_id": "client-id", + "state": "bar", + "scope": "openid profile", + "redirect_uri": "https://client.test", + "user_id": "1", + }, + ) + assert "code=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + assert params["state"] == "bar" + + code = params["code"] + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "redirect_uri": "https://client.test", + "code": code, + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert "id_token" in resp + + token = jwt.decode( + resp["id_token"], + key=public_key, + algorithms=[alg], + ) + claims = CodeIDToken( + token.claims, + token.header, + {"iss": {"value": "Authlib"}}, + ) + claims.validate() + + +def test_deprecated_get_jwt_config_signature(test_client, server, db, user): + """Using the old get_jwt_config(self, grant) signature should emit a DeprecationWarning.""" + + class DeprecatedOpenIDCode(_OpenIDCode): + def get_jwt_config(self, grant): + return {"key": "secret", "alg": "HS256", "iss": "Authlib", "exp": 3600} + + def exists_nonce(self, nonce, request): + return exists_nonce(nonce, request) + + def generate_user_info(self, user, scopes): + return user.generate_user_info(scopes) + + class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + server.register_grant(AuthorizationCodeGrant, [DeprecatedOpenIDCode()]) + + client = Client( + user_id=user.id, + client_id="deprecated-client", + client_secret="secret", + ) + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "openid profile", + "response_types": ["code"], + "grant_types": ["authorization_code"], } - rv = self.client.post('/oauth/authorize', data=data) - self.assertIn('code=', rv.location) - - rv = self.client.post('/oauth/authorize', data=data) - self.assertIn('error=', rv.location) - - def test_prompt(self): - self.prepare_data() - params = [ - ('response_type', 'code'), - ('client_id', 'code-client'), - ('state', 'bar'), - ('nonce', 'abc'), - ('scope', 'openid profile'), - ('redirect_uri', 'https://a.b') - ] - query = url_encode(params) - rv = self.client.get('/oauth/authorize?' + query) - self.assertEqual(rv.data, b'login') - - query = url_encode(params + [('user_id', '1')]) - rv = self.client.get('/oauth/authorize?' + query) - self.assertEqual(rv.data, b'ok') - - query = url_encode(params + [('prompt', 'login')]) - rv = self.client.get('/oauth/authorize?' + query) - self.assertEqual(rv.data, b'login') - - -class RSAOpenIDCodeTest(BaseTestCase): - def config_app(self): - self.app.config.update({ - 'OAUTH2_JWT_ENABLED': True, - 'OAUTH2_JWT_ISS': 'Authlib', - 'OAUTH2_JWT_KEY_PATH': get_file_path('jwk_private.json'), - 'OAUTH2_JWT_ALG': 'RS256', - }) - - def get_validate_key(self): - with open(get_file_path('jwk_public.json'), 'r') as f: - return json.load(f) - - def test_authorize_token(self): - # generate refresh token - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'code', - 'client_id': 'code-client', - 'state': 'bar', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1' - }) - self.assertIn('code=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - self.assertEqual(params['state'], 'bar') - - code = params['code'] - headers = self.create_basic_header('code-client', 'code-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'redirect_uri': 'https://a.b', - 'code': code, - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - self.assertIn('id_token', resp) - - jwt = JsonWebToken() - claims = jwt.decode( - resp['id_token'], - self.get_validate_key(), - claims_cls=CodeIDToken, - claims_options={'iss': {'value': 'Authlib'}} + ) + db.session.add(client) + db.session.commit() + + rv = test_client.post( + "/oauth/authorize", + data={ + "response_type": "code", + "client_id": "deprecated-client", + "state": "bar", + "scope": "openid profile", + "redirect_uri": "https://client.test", + "user_id": "1", + }, + ) + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + code = params["code"] + + with pytest.warns(DeprecationWarning, match="get_jwt_config.*version 1.8"): + test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "redirect_uri": "https://client.test", + "code": code, + }, + headers=create_basic_header("deprecated-client", "secret"), ) - claims.validate() - - -class JWKSOpenIDCodeTest(RSAOpenIDCodeTest): - def config_app(self): - self.app.config.update({ - 'OAUTH2_JWT_ENABLED': True, - 'OAUTH2_JWT_ISS': 'Authlib', - 'OAUTH2_JWT_KEY_PATH': get_file_path('jwks_private.json'), - 'OAUTH2_JWT_ALG': 'PS256', - }) - - def get_validate_key(self): - with open(get_file_path('jwks_public.json'), 'r') as f: - return JsonWebKey.import_key_set(json.load(f)) - - -class ECOpenIDCodeTest(RSAOpenIDCodeTest): - def config_app(self): - self.app.config.update({ - 'OAUTH2_JWT_ENABLED': True, - 'OAUTH2_JWT_ISS': 'Authlib', - 'OAUTH2_JWT_KEY_PATH': get_file_path('ec_private.json'), - 'OAUTH2_JWT_ALG': 'ES256', - }) - - def get_validate_key(self): - with open(get_file_path('ec_public.json'), 'r') as f: - return json.load(f) - - -class PEMOpenIDCodeTest(RSAOpenIDCodeTest): - def config_app(self): - self.app.config.update({ - 'OAUTH2_JWT_ENABLED': True, - 'OAUTH2_JWT_ISS': 'Authlib', - 'OAUTH2_JWT_KEY_PATH': get_file_path('rsa_private.pem'), - 'OAUTH2_JWT_ALG': 'RS256', - }) - - def get_validate_key(self): - with open(get_file_path('rsa_public.pem'), 'r') as f: - return f.read() diff --git a/tests/flask/test_oauth2/test_openid_hybrid_grant.py b/tests/flask/test_oauth2/test_openid_hybrid_grant.py index e596c4d48..5aeb3726b 100644 --- a/tests/flask/test_oauth2/test_openid_hybrid_grant.py +++ b/tests/flask/test_oauth2/test_openid_hybrid_grant.py @@ -1,285 +1,329 @@ +import pytest from flask import json -from authlib.common.urls import urlparse, url_decode -from authlib.jose import JWT -from authlib.oidc.core import HybridIDToken -from authlib.oidc.core.grants import ( - OpenIDCode as _OpenIDCode, - OpenIDHybridGrant as _OpenIDHybridGrant, -) + +from authlib.common.urls import url_decode +from authlib.common.urls import urlparse +from authlib.jose import jwt from authlib.oauth2.rfc6749.grants import ( AuthorizationCodeGrant as _AuthorizationCodeGrant, ) -from .models import db, User, Client, exists_nonce -from .models import CodeGrantMixin, save_authorization_code -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server - -JWT_CONFIG = {'iss': 'Authlib', 'key': 'secret', 'alg': 'HS256', 'exp': 3600} - - -class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): - def save_authorization_code(self, code, request): - return save_authorization_code(code, request) - - -class OpenIDCode(_OpenIDCode): - def get_jwt_config(self, grant): - return dict(JWT_CONFIG) - - def exists_nonce(self, nonce, request): - return exists_nonce(nonce, request) - - def generate_user_info(self, user, scopes): - return user.generate_user_info(scopes) - - -class OpenIDHybridGrant(_OpenIDHybridGrant): - def save_authorization_code(self, code, request): - return save_authorization_code(code, request) - - def get_jwt_config(self): - return dict(JWT_CONFIG) - - def exists_nonce(self, nonce, request): - return exists_nonce(nonce, request) - - def generate_user_info(self, user, scopes): - return user.generate_user_info(scopes) - - -class OpenIDCodeTest(TestCase): - def prepare_data(self): - server = create_authorization_server(self.app) - server.register_grant(OpenIDHybridGrant) - server.register_grant(AuthorizationCodeGrant, [OpenIDCode()]) - - user = User(username='foo') - db.session.add(user) - db.session.commit() - - client = Client( - user_id=user.id, - client_id='hybrid-client', - client_secret='hybrid-secret', - ) - client.set_client_metadata({ - 'redirect_uris': ['https://a.b'], - 'scope': 'openid profile address', - 'response_types': ['code id_token', 'code token', 'code id_token token'], - 'grant_types': ['authorization_code'], - }) - db.session.add(client) - db.session.commit() - - def validate_claims(self, id_token, params): - jwt = JWT() - claims = jwt.decode( - id_token, 'secret', - claims_cls=HybridIDToken, - claims_params=params - ) - claims.validate() - - def test_invalid_client_id(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'code token', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'invalid-client', - 'response_type': 'code token', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - def test_require_nonce(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code token', - 'scope': 'openid profile', - 'state': 'bar', - 'redirect_uri': 'https://a.b', - 'user_id': '1' - }) - self.assertIn('error=invalid_request', rv.location) - self.assertIn('nonce', rv.location) - - def test_invalid_response_type(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code id_token invalid', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_grant') - - def test_invalid_scope(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code id_token', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) - self.assertIn('error=invalid_scope', rv.location) - - def test_access_denied(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code token', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - }) - self.assertIn('error=access_denied', rv.location) - - def test_code_access_token(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code token', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) - self.assertIn('code=', rv.location) - self.assertIn('access_token=', rv.location) - self.assertNotIn('id_token=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).fragment)) - self.assertEqual(params['state'], 'bar') - - code = params['code'] - headers = self.create_basic_header('hybrid-client', 'hybrid-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'redirect_uri': 'https://a.b', - 'code': code, - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - self.assertIn('id_token', resp) - - def test_code_id_token(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code id_token', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) - self.assertIn('code=', rv.location) - self.assertIn('id_token=', rv.location) - self.assertNotIn('access_token=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).fragment)) - self.assertEqual(params['state'], 'bar') - - params['nonce'] = 'abc' - params['client_id'] = 'hybrid-client' - self.validate_claims(params['id_token'], params) - - code = params['code'] - headers = self.create_basic_header('hybrid-client', 'hybrid-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'redirect_uri': 'https://a.b', - 'code': code, - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - self.assertIn('id_token', resp) - - def test_code_id_token_access_token(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code id_token token', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) - self.assertIn('code=', rv.location) - self.assertIn('id_token=', rv.location) - self.assertIn('access_token=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).fragment)) - self.assertEqual(params['state'], 'bar') - self.validate_claims(params['id_token'], params) - - code = params['code'] - headers = self.create_basic_header('hybrid-client', 'hybrid-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'redirect_uri': 'https://a.b', - 'code': code, - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - self.assertIn('id_token', resp) - - def test_response_mode_query(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code id_token token', - 'response_mode': 'query', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) - self.assertIn('code=', rv.location) - self.assertIn('id_token=', rv.location) - self.assertIn('access_token=', rv.location) - - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - self.assertEqual(params['state'], 'bar') - - def test_response_mode_form_post(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code id_token token', - 'response_mode': 'form_post', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) - self.assertIn(b'name="code"', rv.data) - self.assertIn(b'name="id_token"', rv.data) - self.assertIn(b'name="access_token"', rv.data) +from authlib.oidc.core import HybridIDToken +from authlib.oidc.core.grants import OpenIDCode as _OpenIDCode +from authlib.oidc.core.grants import OpenIDHybridGrant as _OpenIDHybridGrant + +from .models import CodeGrantMixin +from .models import exists_nonce +from .models import save_authorization_code +from .oauth2_server import create_basic_header + +JWT_CONFIG = {"iss": "Authlib", "key": "secret", "alg": "HS256", "exp": 3600} + + +@pytest.fixture(autouse=True) +def server(server): + class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + class OpenIDCode(_OpenIDCode): + def get_jwt_config(self, grant, client): + return dict(JWT_CONFIG) + + def exists_nonce(self, nonce, request): + return exists_nonce(nonce, request) + + def generate_user_info(self, user, scopes): + return user.generate_user_info(scopes) + + class OpenIDHybridGrant(_OpenIDHybridGrant): + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + def get_jwt_config(self, client): + return dict(JWT_CONFIG) + + def exists_nonce(self, nonce, request): + return exists_nonce(nonce, request) + + def generate_user_info(self, user, scopes): + return user.generate_user_info(scopes) + + server.register_grant(OpenIDHybridGrant) + server.register_grant(AuthorizationCodeGrant, [OpenIDCode()]) + + return server + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "redirect_uris": ["https://client.test"], + "scope": "openid profile address", + "response_types": [ + "code id_token", + "code token", + "code id_token token", + ], + "grant_types": ["authorization_code"], + } + ) + db.session.add(client) + db.session.commit() + return client + + +def validate_claims(id_token, params): + claims = jwt.decode( + id_token, "secret", claims_cls=HybridIDToken, claims_params=params + ) + claims.validate() + + +def test_invalid_client_id(test_client): + rv = test_client.post( + "/oauth/authorize", + data={ + "response_type": "code token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://client.test", + "user_id": "1", + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + rv = test_client.post( + "/oauth/authorize", + data={ + "client_id": "invalid-client", + "response_type": "code token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://client.test", + "user_id": "1", + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + +def test_require_nonce(test_client): + rv = test_client.post( + "/oauth/authorize", + data={ + "client_id": "client-id", + "response_type": "code token", + "scope": "openid profile", + "state": "bar", + "redirect_uri": "https://client.test", + "user_id": "1", + }, + ) + assert "error=invalid_request" in rv.location + assert "nonce" in rv.location + + +def test_invalid_response_type(test_client): + rv = test_client.post( + "/oauth/authorize", + data={ + "client_id": "client-id", + "response_type": "code id_token invalid", + "state": "bar", + "nonce": "abc", + "scope": "profile", + "redirect_uri": "https://client.test", + "user_id": "1", + }, + ) + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + assert params["error"] == "unsupported_response_type" + + +def test_invalid_scope(test_client): + rv = test_client.post( + "/oauth/authorize", + data={ + "client_id": "client-id", + "response_type": "code id_token", + "state": "bar", + "nonce": "abc", + "scope": "profile", + "redirect_uri": "https://client.test", + "user_id": "1", + }, + ) + assert "error=invalid_scope" in rv.location + + +def test_access_denied(test_client): + rv = test_client.post( + "/oauth/authorize", + data={ + "client_id": "client-id", + "response_type": "code token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://client.test", + }, + ) + assert "error=access_denied" in rv.location + + +def test_code_access_token(test_client): + rv = test_client.post( + "/oauth/authorize", + data={ + "client_id": "client-id", + "response_type": "code token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://client.test", + "user_id": "1", + }, + ) + assert "code=" in rv.location + assert "access_token=" in rv.location + assert "id_token=" not in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).fragment)) + assert params["state"] == "bar" + + code = params["code"] + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "redirect_uri": "https://client.test", + "code": code, + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert "id_token" in resp + + +def test_code_id_token(test_client): + rv = test_client.post( + "/oauth/authorize", + data={ + "client_id": "client-id", + "response_type": "code id_token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://client.test", + "user_id": "1", + }, + ) + assert "code=" in rv.location + assert "id_token=" in rv.location + assert "access_token=" not in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).fragment)) + assert params["state"] == "bar" + + params["nonce"] = "abc" + params["client_id"] = "client-id" + validate_claims(params["id_token"], params) + + code = params["code"] + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "redirect_uri": "https://client.test", + "code": code, + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert "id_token" in resp + + +def test_code_id_token_access_token(test_client): + rv = test_client.post( + "/oauth/authorize", + data={ + "client_id": "client-id", + "response_type": "code id_token token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://client.test", + "user_id": "1", + }, + ) + assert "code=" in rv.location + assert "id_token=" in rv.location + assert "access_token=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).fragment)) + assert params["state"] == "bar" + validate_claims(params["id_token"], params) + + code = params["code"] + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "redirect_uri": "https://client.test", + "code": code, + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert "id_token" in resp + + +def test_response_mode_query(test_client): + rv = test_client.post( + "/oauth/authorize", + data={ + "client_id": "client-id", + "response_type": "code id_token token", + "response_mode": "query", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://client.test", + "user_id": "1", + }, + ) + assert "code=" in rv.location + assert "id_token=" in rv.location + assert "access_token=" in rv.location + + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + assert params["state"] == "bar" + + +def test_response_mode_form_post(test_client): + rv = test_client.post( + "/oauth/authorize", + data={ + "client_id": "client-id", + "response_type": "code id_token token", + "response_mode": "form_post", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://client.test", + "user_id": "1", + }, + ) + assert b'name="code"' in rv.data + assert b'name="id_token"' in rv.data + assert b'name="access_token"' in rv.data diff --git a/tests/flask/test_oauth2/test_openid_implict_grant.py b/tests/flask/test_oauth2/test_openid_implict_grant.py index 6b66086bc..1a24d51af 100644 --- a/tests/flask/test_oauth2/test_openid_implict_grant.py +++ b/tests/flask/test_oauth2/test_openid_implict_grant.py @@ -1,172 +1,303 @@ -from authlib.jose import JWT +import pytest +from flask import current_app + +from authlib.common.urls import add_params_to_uri +from authlib.common.urls import url_decode +from authlib.common.urls import urlparse +from authlib.jose import JsonWebToken +from authlib.oauth2.rfc6749.requests import BasicOAuth2Payload +from authlib.oauth2.rfc6749.requests import OAuth2Request from authlib.oidc.core import ImplicitIDToken -from authlib.oidc.core.grants import ( - OpenIDImplicitGrant as _OpenIDImplicitGrant -) -from authlib.common.urls import urlparse, url_decode, add_params_to_uri -from .models import db, User, Client, exists_nonce -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server - - -class OpenIDImplicitGrant(_OpenIDImplicitGrant): - def get_jwt_config(self): - return dict(key='secret', alg='HS256', iss='Authlib', exp=3600) - - def generate_user_info(self, user, scopes): - return user.generate_user_info(scopes) - - def exists_nonce(self, nonce, request): - return exists_nonce(nonce, request) - - -class ImplicitTest(TestCase): - def prepare_data(self): - server = create_authorization_server(self.app) - server.register_grant(OpenIDImplicitGrant) - - user = User(username='foo') - db.session.add(user) - db.session.commit() - client = Client( - user_id=user.id, - client_id='implicit-client', - client_secret='', - ) - client.set_client_metadata({ - 'redirect_uris': ['https://a.b/c'], - 'scope': 'openid profile', - 'token_endpoint_auth_method': 'none', - 'response_types': ['id_token', 'id_token token'], - }) - self.authorize_url = ( - '/oauth/authorize?response_type=token' - '&client_id=implicit-client' - ) - db.session.add(client) - db.session.commit() - - def validate_claims(self, id_token, params): - jwt = JWT(['HS256']) - claims = jwt.decode( - id_token, 'secret', - claims_cls=ImplicitIDToken, - claims_params=params +from authlib.oidc.core.grants import OpenIDImplicitGrant as _OpenIDImplicitGrant + +from .models import Client +from .models import exists_nonce + +authorize_url = "/oauth/authorize?response_type=token&client_id=client-id" + + +@pytest.fixture(autouse=True) +def server(server): + class OpenIDImplicitGrant(_OpenIDImplicitGrant): + def get_jwt_config(self, client): + alg = current_app.config.get("OAUTH2_JWT_ALG", "HS256") + return dict(key="secret", alg=alg, iss="Authlib", exp=3600) + + def generate_user_info(self, user, scopes): + return user.generate_user_info(scopes) + + def exists_nonce(self, nonce, request): + return exists_nonce(nonce, request) + + server.register_grant(OpenIDImplicitGrant) + return server + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "redirect_uris": ["https://client.test/callback"], + "scope": "openid profile", + "token_endpoint_auth_method": "none", + "response_types": ["id_token", "id_token token"], + } + ) + db.session.add(client) + db.session.commit() + return client + + +def validate_claims(id_token, params, alg="HS256"): + jwt = JsonWebToken([alg]) + claims = jwt.decode( + id_token, "secret", claims_cls=ImplicitIDToken, claims_params=params + ) + claims.validate() + return claims + + +def test_consent_view(test_client): + rv = test_client.get( + add_params_to_uri( + "/oauth/authorize", + { + "response_type": "id_token", + "client_id": "client-id", + "scope": "openid profile", + "state": "foo", + "redirect_uri": "https://client.test/callback", + "user_id": "1", + }, ) - claims.validate() - - def test_consent_view(self): - self.prepare_data() - rv = self.client.get(add_params_to_uri('/oauth/authorize', { - 'response_type': 'id_token', - 'client_id': 'implicit-client', - 'scope': 'openid profile', - 'state': 'foo', - 'redirect_uri': 'https://a.b/c', - 'user_id': '1' - })) - self.assertIn(b'error=invalid_request', rv.data) - self.assertIn(b'nonce', rv.data) - - def test_require_nonce(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'id_token', - 'client_id': 'implicit-client', - 'scope': 'openid profile', - 'state': 'bar', - 'redirect_uri': 'https://a.b/c', - 'user_id': '1' - }) - self.assertIn('error=invalid_request', rv.location) - self.assertIn('nonce', rv.location) - - def test_missing_openid_in_scope(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'id_token token', - 'client_id': 'implicit-client', - 'scope': 'profile', - 'state': 'bar', - 'nonce': 'abc', - 'redirect_uri': 'https://a.b/c', - 'user_id': '1' - }) - self.assertIn('error=invalid_scope', rv.location) - - def test_denied(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'id_token', - 'client_id': 'implicit-client', - 'scope': 'openid profile', - 'state': 'bar', - 'nonce': 'abc', - 'redirect_uri': 'https://a.b/c', - }) - self.assertIn('error=access_denied', rv.location) - - def test_authorize_access_token(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'id_token token', - 'client_id': 'implicit-client', - 'scope': 'openid profile', - 'state': 'bar', - 'nonce': 'abc', - 'redirect_uri': 'https://a.b/c', - 'user_id': '1' - }) - self.assertIn('access_token=', rv.location) - self.assertIn('id_token=', rv.location) - self.assertIn('state=bar', rv.location) - params = dict(url_decode(urlparse.urlparse(rv.location).fragment)) - self.validate_claims(params['id_token'], params) - - def test_authorize_id_token(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'id_token', - 'client_id': 'implicit-client', - 'scope': 'openid profile', - 'state': 'bar', - 'nonce': 'abc', - 'redirect_uri': 'https://a.b/c', - 'user_id': '1' - }) - self.assertIn('id_token=', rv.location) - self.assertIn('state=bar', rv.location) - params = dict(url_decode(urlparse.urlparse(rv.location).fragment)) - self.validate_claims(params['id_token'], params) - - def test_response_mode_query(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'id_token', - 'response_mode': 'query', - 'client_id': 'implicit-client', - 'scope': 'openid profile', - 'state': 'bar', - 'nonce': 'abc', - 'redirect_uri': 'https://a.b/c', - 'user_id': '1' - }) - self.assertIn('id_token=', rv.location) - self.assertIn('state=bar', rv.location) - params = dict(url_decode(urlparse.urlparse(rv.location).query)) - self.validate_claims(params['id_token'], params) - - def test_response_mode_form_post(self): - self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'id_token', - 'response_mode': 'form_post', - 'client_id': 'implicit-client', - 'scope': 'openid profile', - 'state': 'bar', - 'nonce': 'abc', - 'redirect_uri': 'https://a.b/c', - 'user_id': '1' - }) - self.assertIn(b'name="id_token"', rv.data) - self.assertIn(b'name="state"', rv.data) + ) + assert "error=invalid_request" in rv.location + assert "nonce" in rv.location + + +def test_require_nonce(test_client): + rv = test_client.post( + "/oauth/authorize", + data={ + "response_type": "id_token", + "client_id": "client-id", + "scope": "openid profile", + "state": "bar", + "redirect_uri": "https://client.test/callback", + "user_id": "1", + }, + ) + assert "error=invalid_request" in rv.location + assert "nonce" in rv.location + + +def test_missing_openid_in_scope(test_client): + rv = test_client.post( + "/oauth/authorize", + data={ + "response_type": "id_token token", + "client_id": "client-id", + "scope": "profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://client.test/callback", + "user_id": "1", + }, + ) + assert "error=invalid_scope" in rv.location + + +def test_denied(test_client): + rv = test_client.post( + "/oauth/authorize", + data={ + "response_type": "id_token", + "client_id": "client-id", + "scope": "openid profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://client.test/callback", + }, + ) + assert "error=access_denied" in rv.location + + +def test_authorize_access_token(test_client): + rv = test_client.post( + "/oauth/authorize", + data={ + "response_type": "id_token token", + "client_id": "client-id", + "scope": "openid profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://client.test/callback", + "user_id": "1", + }, + ) + assert "access_token=" in rv.location + assert "id_token=" in rv.location + assert "state=bar" in rv.location + params = dict(url_decode(urlparse.urlparse(rv.location).fragment)) + validate_claims(params["id_token"], params) + + +def test_authorize_id_token(test_client): + rv = test_client.post( + "/oauth/authorize", + data={ + "response_type": "id_token", + "client_id": "client-id", + "scope": "openid profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://client.test/callback", + "user_id": "1", + }, + ) + assert "id_token=" in rv.location + assert "state=bar" in rv.location + params = dict(url_decode(urlparse.urlparse(rv.location).fragment)) + validate_claims(params["id_token"], params) + + +def test_response_mode_query(test_client): + rv = test_client.post( + "/oauth/authorize", + data={ + "response_type": "id_token", + "response_mode": "query", + "client_id": "client-id", + "scope": "openid profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://client.test/callback", + "user_id": "1", + }, + ) + assert "id_token=" in rv.location + assert "state=bar" in rv.location + params = dict(url_decode(urlparse.urlparse(rv.location).query)) + validate_claims(params["id_token"], params) + + +def test_response_mode_form_post(test_client): + rv = test_client.post( + "/oauth/authorize", + data={ + "response_type": "id_token", + "response_mode": "form_post", + "client_id": "client-id", + "scope": "openid profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://client.test/callback", + "user_id": "1", + }, + ) + assert b'name="id_token"' in rv.data + assert b'name="state"' in rv.data + + +def test_client_metadata_custom_alg(test_client, app, db, client): + """If the client metadata 'id_token_signed_response_alg' is defined, + it should be used to sign id_tokens.""" + client.set_client_metadata( + { + "redirect_uris": ["https://client.test/callback"], + "scope": "openid profile", + "token_endpoint_auth_method": "none", + "response_types": ["id_token", "id_token token"], + "id_token_signed_response_alg": "HS384", + } + ) + db.session.add(client) + db.session.commit() + + app.config["OAUTH2_JWT_ALG"] = None + rv = test_client.post( + "/oauth/authorize", + data={ + "response_type": "id_token", + "client_id": "client-id", + "scope": "openid profile", + "state": "foo", + "redirect_uri": "https://client.test/callback", + "user_id": "1", + "nonce": "abc", + }, + ) + params = dict(url_decode(urlparse.urlparse(rv.location).fragment)) + claims = validate_claims(params["id_token"], params, "HS384") + assert claims.header["alg"] == "HS384" + + +def test_client_metadata_alg_none(test_client, app, db, client): + """The 'none' 'id_token_signed_response_alg' alg should be + forbidden in non implicit flows.""" + client.set_client_metadata( + { + "redirect_uris": ["https://client.test/callback"], + "scope": "openid profile", + "token_endpoint_auth_method": "none", + "response_types": ["id_token", "id_token token"], + "id_token_signed_response_alg": "none", + } + ) + db.session.add(client) + db.session.commit() + + app.config["OAUTH2_JWT_ALG"] = None + rv = test_client.post( + "/oauth/authorize", + data={ + "response_type": "id_token", + "client_id": "client-id", + "scope": "openid profile", + "state": "foo", + "redirect_uri": "https://client.test/callback", + "user_id": "1", + "nonce": "abc", + }, + ) + params = dict(url_decode(urlparse.urlparse(rv.location).fragment)) + assert params["error"] == "invalid_request" + + +def test_deprecated_get_jwt_config_signature(user): + """Using the old get_jwt_config(self) signature should emit a DeprecationWarning.""" + + class DeprecatedImplicitGrant(_OpenIDImplicitGrant): + def get_jwt_config(self): + return {"key": "secret", "alg": "HS256", "iss": "Authlib", "exp": 3600} + + def generate_user_info(self, user, scopes): + return user.generate_user_info(scopes) + + def exists_nonce(self, nonce, request): + return exists_nonce(nonce, request) + + client = Client( + user_id=user.id, + client_id="deprecated-client", + client_secret="secret", + ) + client.set_client_metadata( + { + "redirect_uris": ["https://client.test/callback"], + "scope": "openid profile", + "token_endpoint_auth_method": "none", + "response_types": ["id_token"], + } + ) + + request = OAuth2Request("POST", "https://server.test/authorize") + request.payload = BasicOAuth2Payload({"nonce": "test-nonce"}) + request.client = client + request.user = user + + grant = DeprecatedImplicitGrant(request, client) + token = {"scope": "openid", "expires_in": 3600} + + with pytest.warns(DeprecationWarning, match="get_jwt_config.*version 1.8"): + grant.process_implicit_token(token) diff --git a/tests/flask/test_oauth2/test_password_grant.py b/tests/flask/test_oauth2/test_password_grant.py index 7e7d21500..2d7f1f32f 100644 --- a/tests/flask/test_oauth2/test_password_grant.py +++ b/tests/flask/test_oauth2/test_password_grant.py @@ -1,11 +1,40 @@ +import pytest from flask import json + from authlib.common.urls import add_params_to_uri from authlib.oauth2.rfc6749.grants import ( ResourceOwnerPasswordCredentialsGrant as _PasswordGrant, ) -from .models import db, User, Client -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server +from authlib.oidc.core import OpenIDToken + +from .models import User +from .oauth2_server import create_basic_header + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "scope": "openid profile", + "grant_types": ["password"], + "redirect_uris": ["https://client.test/authorized"], + } + ) + db.session.add(client) + db.session.commit() + return client + + +class IDToken(OpenIDToken): + def get_jwt_config(self, grant, client): + return { + "iss": "Authlib", + "key": "secret", + "alg": "HS256", + } + + def generate_user_info(self, user, scopes): + return user.generate_user_info(scopes) class PasswordGrant(_PasswordGrant): @@ -15,152 +44,207 @@ def authenticate_user(self, username, password): return user -class PasswordTest(TestCase): - def prepare_data(self, grant_type='password'): - server = create_authorization_server(self.app) - server.register_grant(PasswordGrant) - self.server = server - - user = User(username='foo') - db.session.add(user) - db.session.commit() - client = Client( - user_id=user.id, - client_id='password-client', - client_secret='password-secret', - ) - client.set_client_metadata({ - 'scope': 'profile', - 'grant_types': [grant_type], - 'redirect_uris': ['http://localhost/authorized'], - }) - db.session.add(client) - db.session.commit() - - def test_invalid_client(self): - self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'ok', - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - headers = self.create_basic_header( - 'password-client', 'invalid-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'ok', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - def test_invalid_scope(self): - self.prepare_data() - self.server.metadata = {'scopes_supported': ['profile']} - headers = self.create_basic_header( - 'password-client', 'password-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'ok', - 'scope': 'invalid', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_scope') - - def test_invalid_request(self): - self.prepare_data() - headers = self.create_basic_header( - 'password-client', 'password-secret' - ) - - rv = self.client.get(add_params_to_uri('/oauth/token', { - 'grant_type': 'password', - }), headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'unsupported_grant_type') - - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'wrong', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - def test_invalid_grant_type(self): - self.prepare_data(grant_type='invalid') - headers = self.create_basic_header( - 'password-client', 'password-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'ok', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'unauthorized_client') - - def test_authorize_token(self): - self.prepare_data() - headers = self.create_basic_header( - 'password-client', 'password-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'ok', - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - - def test_token_generator(self): - m = 'tests.flask.test_oauth2.oauth2_server:token_generator' - self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) - self.prepare_data() - headers = self.create_basic_header( - 'password-client', 'password-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'ok', - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - self.assertIn('p-password.1.', resp['access_token']) - - def test_custom_expires_in(self): - self.app.config.update({ - 'OAUTH2_TOKEN_EXPIRES_IN': {'password': 1800} - }) - self.prepare_data() - headers = self.create_basic_header( - 'password-client', 'password-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'ok', - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - self.assertEqual(resp['expires_in'], 1800) +def register_password_grant(server, extensions=None): + server.register_grant(PasswordGrant, extensions) + + +def test_invalid_client(test_client, server): + register_password_grant( + server, + ) + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + headers = create_basic_header("client-id", "invalid-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + +def test_invalid_scope(test_client, server): + register_password_grant( + server, + ) + server.scopes_supported = ["profile"] + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + "scope": "invalid", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_scope" + + +def test_invalid_request(test_client, server): + register_password_grant( + server, + ) + headers = create_basic_header("client-id", "client-secret") + + rv = test_client.get( + add_params_to_uri( + "/oauth/token", + { + "grant_type": "password", + }, + ), + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "unsupported_grant_type" + + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "password", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_request" + + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_request" + + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "wrong", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_request" + + +def test_invalid_grant_type(test_client, server, db, client): + register_password_grant(server) + client.set_client_metadata( + { + "scope": "openid profile", + "grant_types": ["invalid"], + "redirect_uris": ["https://client.test/authorized"], + } + ) + db.session.add(client) + db.session.commit() + + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "unauthorized_client" + + +def test_authorize_token(test_client, server): + register_password_grant( + server, + ) + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + + +def test_token_generator(test_client, server, app): + m = "tests.flask.test_oauth2.oauth2_server:token_generator" + app.config.update({"OAUTH2_ACCESS_TOKEN_GENERATOR": m}) + server.load_config(app.config) + register_password_grant(server) + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert "c-password.1." in resp["access_token"] + + +def test_custom_expires_in(test_client, server, app): + app.config.update({"OAUTH2_TOKEN_EXPIRES_IN": {"password": 1800}}) + server.load_config(app.config) + register_password_grant(server) + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert resp["expires_in"] == 1800 + + +def test_id_token_extension(test_client, server): + register_password_grant(server, extensions=[IDToken()]) + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + "scope": "openid profile", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert "id_token" in resp diff --git a/tests/flask/test_oauth2/test_refresh_token.py b/tests/flask/test_oauth2/test_refresh_token.py index 7fe8e463d..fc62967d3 100644 --- a/tests/flask/test_oauth2/test_refresh_token.py +++ b/tests/flask/test_oauth2/test_refresh_token.py @@ -1,228 +1,262 @@ +import time + +import pytest from flask import json -from authlib.oauth2.rfc6749.grants import ( - RefreshTokenGrant as _RefreshTokenGrant, -) -from .models import db, User, Client, Token -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server - - -class RefreshTokenGrant(_RefreshTokenGrant): - def authenticate_refresh_token(self, refresh_token): - item = Token.query.filter_by(refresh_token=refresh_token).first() - if item and not item.revoked and not item.is_refresh_token_expired(): - return item - - def authenticate_user(self, credential): - return User.query.get(credential.user_id) - - def revoke_old_credential(self, credential): - credential.revoked = True - db.session.add(credential) - db.session.commit() - - -class RefreshTokenTest(TestCase): - def prepare_data(self, grant_type='refresh_token'): - server = create_authorization_server(self.app) - server.register_grant(RefreshTokenGrant) - - user = User(username='foo') - db.session.add(user) - db.session.commit() - client = Client( - user_id=user.id, - client_id='refresh-client', - client_secret='refresh-secret', - ) - client.set_client_metadata({ - 'scope': 'profile', - 'grant_types': [grant_type], - 'redirect_uris': ['http://localhost/authorized'], - }) - db.session.add(client) - db.session.commit() - - def create_token(self, scope='profile', user_id=1): - token = Token( - user_id=user_id, - client_id='refresh-client', - token_type='bearer', - access_token='a1', - refresh_token='r1', - scope=scope, - expires_in=3600, - ) - db.session.add(token) - db.session.commit() - - def test_invalid_client(self): - self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'refresh_token', - 'refresh_token': 'foo', - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - headers = self.create_basic_header( - 'invalid-client', 'refresh-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'refresh_token', - 'refresh_token': 'foo', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - headers = self.create_basic_header( - 'refresh-client', 'invalid-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'refresh_token', - 'refresh_token': 'foo', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - def test_invalid_refresh_token(self): - self.prepare_data() - headers = self.create_basic_header( - 'refresh-client', 'refresh-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'refresh_token', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - self.assertIn('Missing', resp['error_description']) - - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'refresh_token', - 'refresh_token': 'foo', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_grant') - - def test_invalid_scope(self): - self.prepare_data() - self.create_token() - headers = self.create_basic_header( - 'refresh-client', 'refresh-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'refresh_token', - 'refresh_token': 'r1', - 'scope': 'invalid', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_scope') - - def test_invalid_scope_none(self): - self.prepare_data() - self.create_token(scope=None) - headers = self.create_basic_header( - 'refresh-client', 'refresh-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'refresh_token', - 'refresh_token': 'r1', - 'scope': 'invalid', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_scope') - - def test_invalid_user(self): - self.prepare_data() - self.create_token(user_id=5) - headers = self.create_basic_header( - 'refresh-client', 'refresh-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'refresh_token', - 'refresh_token': 'r1', - 'scope': 'profile', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - def test_invalid_grant_type(self): - self.prepare_data(grant_type='invalid') - self.create_token() - headers = self.create_basic_header( - 'refresh-client', 'refresh-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'refresh_token', - 'refresh_token': 'r1', - 'scope': 'profile', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'unauthorized_client') - - def test_authorize_token_no_scope(self): - self.prepare_data() - self.create_token() - headers = self.create_basic_header( - 'refresh-client', 'refresh-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'refresh_token', - 'refresh_token': 'r1', - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - - def test_authorize_token_scope(self): - self.prepare_data() - self.create_token() - headers = self.create_basic_header( - 'refresh-client', 'refresh-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'refresh_token', - 'refresh_token': 'r1', - 'scope': 'profile', - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - - def test_revoke_old_credential(self): - self.prepare_data() - self.create_token() - headers = self.create_basic_header( - 'refresh-client', 'refresh-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'refresh_token', - 'refresh_token': 'r1', - 'scope': 'profile', - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'refresh_token', - 'refresh_token': 'r1', - 'scope': 'profile', - }, headers=headers) - self.assertEqual(rv.status_code, 400) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_grant') - - def test_token_generator(self): - m = 'tests.flask.test_oauth2.oauth2_server:token_generator' - self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) - - self.prepare_data() - self.create_token() - headers = self.create_basic_header( - 'refresh-client', 'refresh-secret' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'refresh_token', - 'refresh_token': 'r1', - }, headers=headers) - resp = json.loads(rv.data) - self.assertIn('access_token', resp) - self.assertIn('r-refresh_token.1.', resp['access_token']) + +from authlib.oauth2.rfc6749.grants import RefreshTokenGrant as _RefreshTokenGrant + +from .models import Token +from .models import User +from .models import db +from .oauth2_server import create_basic_header + + +@pytest.fixture(autouse=True) +def server(server): + class RefreshTokenGrant(_RefreshTokenGrant): + def authenticate_refresh_token(self, refresh_token): + item = Token.query.filter_by(refresh_token=refresh_token).first() + if item and item.is_refresh_token_active(): + return item + + def authenticate_user(self, credential): + return db.session.get(User, credential.user_id) + + def revoke_old_credential(self, credential): + now = int(time.time()) + credential.access_token_revoked_at = now + credential.refresh_token_revoked_at = now + db.session.add(credential) + db.session.commit() + + server.register_grant(RefreshTokenGrant) + return server + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "scope": "profile", + "grant_types": ["refresh_token"], + "redirect_uris": ["https://client.test/authorized"], + } + ) + db.session.add(client) + db.session.commit() + return client + + +def test_invalid_client(test_client): + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": "foo", + }, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + headers = create_basic_header("invalid-client", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": "foo", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + headers = create_basic_header("client-id", "invalid-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": "foo", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + +def test_invalid_refresh_token(test_client): + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_request" + assert "Missing" in resp["error_description"] + + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": "foo", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_grant" + + +def test_invalid_scope(test_client, token): + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": "r1", + "scope": "invalid", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_scope" + + +def test_invalid_scope_none(test_client, token): + token.scope = None + db.session.add(token) + db.session.commit() + + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": "r1", + "scope": "invalid", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_scope" + + +def test_invalid_user(test_client, token): + token.user_id = 5 + db.session.add(token) + db.session.commit() + + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": "r1", + "scope": "profile", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_request" + + +def test_invalid_grant_type(test_client, client, db, token): + client.set_client_metadata( + { + "scope": "profile", + "grant_types": ["invalid"], + "redirect_uris": ["https://client.test/authorized"], + } + ) + db.session.add(client) + db.session.commit() + + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": "r1", + "scope": "profile", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "unauthorized_client" + + +def test_authorize_token_no_scope(test_client, token): + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": "r1", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + + +def test_authorize_token_scope(test_client, token): + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": "r1", + "scope": "profile", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + + +def test_revoke_old_credential(test_client, token): + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": "r1", + "scope": "profile", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": "r1", + "scope": "profile", + }, + headers=headers, + ) + assert rv.status_code == 400 + resp = json.loads(rv.data) + assert resp["error"] == "invalid_grant" + + +def test_token_generator(test_client, token, app, server): + m = "tests.flask.test_oauth2.oauth2_server:token_generator" + app.config.update({"OAUTH2_ACCESS_TOKEN_GENERATOR": m}) + server.load_config(app.config) + + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": "r1", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert "access_token" in resp + assert "c-refresh_token.1." in resp["access_token"] diff --git a/tests/flask/test_oauth2/test_revocation_endpoint.py b/tests/flask/test_oauth2/test_revocation_endpoint.py index 70956281a..4339b013f 100644 --- a/tests/flask/test_oauth2/test_revocation_endpoint.py +++ b/tests/flask/test_oauth2/test_revocation_endpoint.py @@ -1,122 +1,145 @@ +import pytest from flask import json + from authlib.integrations.sqla_oauth2 import create_revocation_endpoint -from .models import db, User, Client, Token -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server - - -RevocationEndpoint = create_revocation_endpoint(db.session, Token) - - -class RevokeTokenTest(TestCase): - def prepare_data(self): - app = self.app - server = create_authorization_server(app) - server.register_endpoint(RevocationEndpoint) - - @app.route('/oauth/revoke', methods=['POST']) - def revoke_token(): - return server.create_endpoint_response('revocation') - - user = User(username='foo') - db.session.add(user) - db.session.commit() - client = Client( - user_id=user.id, - client_id='revoke-client', - client_secret='revoke-secret', - ) - client.set_client_metadata({ - 'scope': 'profile', - 'redirect_uris': ['http://localhost/authorized'], - }) - db.session.add(client) - db.session.commit() - - def create_token(self): - token = Token( - user_id=1, - client_id='revoke-client', - token_type='bearer', - access_token='a1', - refresh_token='r1', - scope='profile', - expires_in=3600, - ) - db.session.add(token) - db.session.commit() - - def test_invalid_client(self): - self.prepare_data() - rv = self.client.post('/oauth/revoke') - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - headers = {'Authorization': 'invalid token_string'} - rv = self.client.post('/oauth/revoke', headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - headers = self.create_basic_header( - 'invalid-client', 'revoke-secret' - ) - rv = self.client.post('/oauth/revoke', headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - headers = self.create_basic_header( - 'revoke-client', 'invalid-secret' - ) - rv = self.client.post('/oauth/revoke', headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_client') - - def test_invalid_token(self): - self.prepare_data() - headers = self.create_basic_header( - 'revoke-client', 'revoke-secret' - ) - rv = self.client.post('/oauth/revoke', headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - rv = self.client.post('/oauth/revoke', data={ - 'token': 'invalid-token', - }, headers=headers) - self.assertEqual(rv.status_code, 200) - - rv = self.client.post('/oauth/revoke', data={ - 'token': 'a1', - 'token_type_hint': 'unsupported_token_type', - }, headers=headers) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'unsupported_token_type') - - rv = self.client.post('/oauth/revoke', data={ - 'token': 'a1', - 'token_type_hint': 'refresh_token', - }, headers=headers) - self.assertEqual(rv.status_code, 200) - - def test_revoke_token_with_hint(self): - self.prepare_data() - self.create_token() - headers = self.create_basic_header( - 'revoke-client', 'revoke-secret' - ) - rv = self.client.post('/oauth/revoke', data={ - 'token': 'a1', - 'token_type_hint': 'access_token', - }, headers=headers) - self.assertEqual(rv.status_code, 200) - - def test_revoke_token_without_hint(self): - self.prepare_data() - self.create_token() - headers = self.create_basic_header( - 'revoke-client', 'revoke-secret' - ) - rv = self.client.post('/oauth/revoke', data={ - 'token': 'a1', - }, headers=headers) - self.assertEqual(rv.status_code, 200) + +from .models import Client +from .models import Token +from .models import db +from .oauth2_server import create_basic_header + + +@pytest.fixture(autouse=True) +def server(server, app): + RevocationEndpoint = create_revocation_endpoint(db.session, Token) + server.register_endpoint(RevocationEndpoint) + + @app.route("/oauth/revoke", methods=["POST"]) + def revoke_token(): + return server.create_endpoint_response("revocation") + + return server + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["https://client.test/authorized"], + } + ) + db.session.add(client) + db.session.commit() + return client + + +def test_invalid_client(test_client): + rv = test_client.post("/oauth/revoke") + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + headers = {"Authorization": "invalid token_string"} + rv = test_client.post("/oauth/revoke", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + headers = create_basic_header("invalid-client", "client-secret") + rv = test_client.post("/oauth/revoke", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + headers = create_basic_header("client-id", "invalid-secret") + rv = test_client.post("/oauth/revoke", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_client" + + +def test_invalid_token(test_client): + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post("/oauth/revoke", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "invalid_request" + + rv = test_client.post( + "/oauth/revoke", + data={ + "token": "invalid-token", + }, + headers=headers, + ) + assert rv.status_code == 200 + + rv = test_client.post( + "/oauth/revoke", + data={ + "token": "a1", + "token_type_hint": "unsupported_token_type", + }, + headers=headers, + ) + resp = json.loads(rv.data) + assert resp["error"] == "unsupported_token_type" + + rv = test_client.post( + "/oauth/revoke", + data={ + "token": "a1", + "token_type_hint": "refresh_token", + }, + headers=headers, + ) + assert rv.status_code == 200 + + +def test_revoke_token_with_hint(test_client, token): + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/revoke", + data={ + "token": "a1", + "token_type_hint": "access_token", + }, + headers=headers, + ) + assert rv.status_code == 200 + + +def test_revoke_token_without_hint(test_client, token): + headers = create_basic_header("client-id", "client-secret") + rv = test_client.post( + "/oauth/revoke", + data={ + "token": "a1", + }, + headers=headers, + ) + assert rv.status_code == 200 + + +def test_revoke_token_bound_to_client(test_client, token): + client2 = Client( + user_id=1, + client_id="client-id-2", + client_secret="client-secret-2", + ) + client2.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["https://client.test/authorized"], + } + ) + db.session.add(client2) + db.session.commit() + + headers = create_basic_header("client-id-2", "client-secret-2") + rv = test_client.post( + "/oauth/revoke", + data={ + "token": "a1", + }, + headers=headers, + ) + assert rv.status_code == 400 + resp = json.loads(rv.data) + assert resp["error"] == "invalid_grant" diff --git a/tests/flask/test_oauth2/test_userinfo.py b/tests/flask/test_oauth2/test_userinfo.py new file mode 100644 index 000000000..bc5d1eb40 --- /dev/null +++ b/tests/flask/test_oauth2/test_userinfo.py @@ -0,0 +1,329 @@ +import pytest +from flask import json +from joserfc import jwt +from joserfc.jwk import KeySet + +import authlib.oidc.core as oidc_core +from authlib.integrations.flask_oauth2 import ResourceProtector +from authlib.integrations.sqla_oauth2 import create_bearer_token_validator +from tests.util import read_file_path + +from .models import Token + + +@pytest.fixture(autouse=True) +def server(server, app, db): + class UserInfoEndpoint(oidc_core.UserInfoEndpoint): + def get_issuer(self) -> str: + return "https://provider.test" + + def generate_user_info(self, user, scope): + return user.generate_user_info().filter(scope) + + def resolve_private_key(self): + return read_file_path("jwks_private.json") + + BearerTokenValidator = create_bearer_token_validator(db.session, Token) + resource_protector = ResourceProtector() + resource_protector.register_token_validator(BearerTokenValidator()) + server.register_endpoint(UserInfoEndpoint(resource_protector=resource_protector)) + + @app.route("/oauth/userinfo", methods=["GET", "POST"]) + def userinfo(): + return server.create_endpoint_response("userinfo") + + return server + + +@pytest.fixture(autouse=True) +def client(client, db): + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["https://client.test/authorized"], + } + ) + db.session.add(client) + db.session.commit() + return client + + +@pytest.fixture(autouse=True) +def token(db): + token = Token( + user_id=1, + client_id="client-id", + token_type="bearer", + access_token="access-token", + refresh_token="r1", + scope="openid", + expires_in=3600, + ) + db.session.add(token) + db.session.commit() + yield token + db.session.delete(token) + + +def test_get(test_client, db, token): + """The UserInfo Endpoint MUST support the use of the HTTP GET and HTTP POST methods defined in RFC 7231 [RFC7231]. + The UserInfo Endpoint MUST accept Access Tokens as OAuth 2.0 Bearer Token Usage [RFC6750].""" + + token.scope = "openid profile email address phone" + db.session.add(token) + db.session.commit() + + headers = {"Authorization": "Bearer access-token"} + rv = test_client.get("/oauth/userinfo", headers=headers) + assert rv.headers["Content-Type"] == "application/json" + + resp = json.loads(rv.data) + assert resp == { + "sub": "1", + "address": { + "country": "USA", + "formatted": "742 Evergreen Terrace, Springfield", + "locality": "Springfield", + "postal_code": "1245", + "region": "Unknown", + "street_address": "742 Evergreen Terrace", + }, + "birthdate": "2000-12-01", + "email": "janedoe@example.com", + "email_verified": True, + "family_name": "Doe", + "gender": "female", + "given_name": "Jane", + "locale": "fr-FR", + "middle_name": "Middle", + "name": "foo", + "nickname": "Jany", + "phone_number": "+1 (425) 555-1212", + "phone_number_verified": False, + "picture": "https://resource.test/janedoe/me.jpg", + "preferred_username": "j.doe", + "profile": "https://resource.test/janedoe", + "updated_at": 1745315119, + "website": "https://resource.test", + "zoneinfo": "Europe/Paris", + } + + +def test_post(test_client, db, token): + """The UserInfo Endpoint MUST support the use of the HTTP GET and HTTP POST methods defined in RFC 7231 [RFC7231]. + The UserInfo Endpoint MUST accept Access Tokens as OAuth 2.0 Bearer Token Usage [RFC6750].""" + + token.scope = "openid profile email address phone" + db.session.add(token) + db.session.commit() + + headers = {"Authorization": "Bearer access-token"} + rv = test_client.post("/oauth/userinfo", headers=headers) + assert rv.headers["Content-Type"] == "application/json" + + resp = json.loads(rv.data) + assert resp == { + "sub": "1", + "address": { + "country": "USA", + "formatted": "742 Evergreen Terrace, Springfield", + "locality": "Springfield", + "postal_code": "1245", + "region": "Unknown", + "street_address": "742 Evergreen Terrace", + }, + "birthdate": "2000-12-01", + "email": "janedoe@example.com", + "email_verified": True, + "family_name": "Doe", + "gender": "female", + "given_name": "Jane", + "locale": "fr-FR", + "middle_name": "Middle", + "name": "foo", + "nickname": "Jany", + "phone_number": "+1 (425) 555-1212", + "phone_number_verified": False, + "picture": "https://resource.test/janedoe/me.jpg", + "preferred_username": "j.doe", + "profile": "https://resource.test/janedoe", + "updated_at": 1745315119, + "website": "https://resource.test", + "zoneinfo": "Europe/Paris", + } + + +def test_no_token(test_client): + rv = test_client.post("/oauth/userinfo") + resp = json.loads(rv.data) + assert resp["error"] == "missing_authorization" + + +def test_bad_token(test_client): + headers = {"Authorization": "invalid token_string"} + rv = test_client.post("/oauth/userinfo", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "unsupported_token_type" + + +def test_token_has_bad_scope(test_client, db, token): + """Test that tokens without 'openid' scope cannot access the userinfo endpoint.""" + + token.scope = "foobar" + db.session.add(token) + db.session.commit() + + headers = {"Authorization": "Bearer access-token"} + rv = test_client.post("/oauth/userinfo", headers=headers) + resp = json.loads(rv.data) + assert resp["error"] == "insufficient_scope" + + +def test_scope_minimum(test_client): + headers = {"Authorization": "Bearer access-token"} + rv = test_client.get("/oauth/userinfo", headers=headers) + resp = json.loads(rv.data) + assert resp == { + "sub": "1", + } + + +def test_scope_profile(test_client, db, token): + token.scope = "openid profile" + db.session.add(token) + db.session.commit() + + headers = {"Authorization": "Bearer access-token"} + rv = test_client.get("/oauth/userinfo", headers=headers) + resp = json.loads(rv.data) + assert resp == { + "sub": "1", + "birthdate": "2000-12-01", + "family_name": "Doe", + "gender": "female", + "given_name": "Jane", + "locale": "fr-FR", + "middle_name": "Middle", + "name": "foo", + "nickname": "Jany", + "picture": "https://resource.test/janedoe/me.jpg", + "preferred_username": "j.doe", + "profile": "https://resource.test/janedoe", + "updated_at": 1745315119, + "website": "https://resource.test", + "zoneinfo": "Europe/Paris", + } + + +def test_scope_address(test_client, db, token): + token.scope = "openid address" + db.session.add(token) + db.session.commit() + + headers = {"Authorization": "Bearer access-token"} + rv = test_client.get("/oauth/userinfo", headers=headers) + resp = json.loads(rv.data) + assert resp == { + "sub": "1", + "address": { + "country": "USA", + "formatted": "742 Evergreen Terrace, Springfield", + "locality": "Springfield", + "postal_code": "1245", + "region": "Unknown", + "street_address": "742 Evergreen Terrace", + }, + } + + +def test_scope_email(test_client, db, token): + token.scope = "openid email" + db.session.add(token) + db.session.commit() + + headers = {"Authorization": "Bearer access-token"} + rv = test_client.get("/oauth/userinfo", headers=headers) + resp = json.loads(rv.data) + assert resp == { + "sub": "1", + "email": "janedoe@example.com", + "email_verified": True, + } + + +def test_scope_phone(test_client, db, token): + token.scope = "openid phone" + db.session.add(token) + db.session.commit() + + headers = {"Authorization": "Bearer access-token"} + rv = test_client.get("/oauth/userinfo", headers=headers) + resp = json.loads(rv.data) + assert resp == { + "sub": "1", + "phone_number": "+1 (425) 555-1212", + "phone_number_verified": False, + } + + +@pytest.mark.skip +def test_scope_signed_unsecured(test_client, db, token, client): + """When userinfo_signed_response_alg is set as client metadata, the userinfo response must be a JWT.""" + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["https://client.test/authorized"], + "userinfo_signed_response_alg": "none", + } + ) + db.session.add(client) + db.session.commit() + + token.scope = "openid email" + db.session.add(token) + db.session.commit() + + headers = {"Authorization": "Bearer access-token"} + rv = test_client.get("/oauth/userinfo", headers=headers) + assert rv.headers["Content-Type"] == "application/jwt" + + # specify that we support "none" + token = jwt.decode(rv.data, None, algorithms=["none"]) + assert token.claims == { + "sub": "1", + "iss": "https://provider.test", + "aud": "client-id", + "email": "janedoe@example.com", + "email_verified": True, + } + + +def test_scope_signed_secured(test_client, client, token, db): + """When userinfo_signed_response_alg is set as client metadata and not none, the userinfo response must be signed.""" + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["https://client.test/authorized"], + "userinfo_signed_response_alg": "RS256", + } + ) + db.session.add(client) + db.session.commit() + + token.scope = "openid email" + db.session.add(token) + db.session.commit() + + headers = {"Authorization": "Bearer access-token"} + rv = test_client.get("/oauth/userinfo", headers=headers) + assert rv.headers["Content-Type"] == "application/jwt" + + pub_key = KeySet.import_key_set(read_file_path("jwks_public.json")) + token = jwt.decode(rv.data, pub_key) + assert token.claims == { + "sub": "1", + "iss": "https://provider.test", + "aud": "client-id", + "email": "janedoe@example.com", + "email_verified": True, + } diff --git a/tests/jose/__init__.py b/tests/jose/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/jose/test_chacha20.py b/tests/jose/test_chacha20.py new file mode 100644 index 000000000..aea4f110d --- /dev/null +++ b/tests/jose/test_chacha20.py @@ -0,0 +1,73 @@ +import pytest +from cryptography.exceptions import InvalidTag + +from authlib.jose import JsonWebEncryption +from authlib.jose import OctKey +from authlib.jose.drafts import register_jwe_draft + +register_jwe_draft(JsonWebEncryption) + + +def test_dir_alg_c20p(): + jwe = JsonWebEncryption() + key = OctKey.generate_key(256, is_private=True) + protected = {"alg": "dir", "enc": "C20P"} + data = jwe.serialize_compact(protected, b"hello", key) + rv = jwe.deserialize_compact(data, key) + assert rv["payload"] == b"hello" + + key2 = OctKey.generate_key(128, is_private=True) + with pytest.raises(InvalidTag): + jwe.deserialize_compact(data, key2) + + with pytest.raises(ValueError): + jwe.serialize_compact(protected, b"hello", key2) + + +def test_dir_alg_xc20p(): + pytest.importorskip("Cryptodome.Cipher.ChaCha20_Poly1305") + + jwe = JsonWebEncryption() + key = OctKey.generate_key(256, is_private=True) + protected = {"alg": "dir", "enc": "XC20P"} + data = jwe.serialize_compact(protected, b"hello", key) + rv = jwe.deserialize_compact(data, key) + assert rv["payload"] == b"hello" + + key2 = OctKey.generate_key(128, is_private=True) + with pytest.raises(ValueError): + jwe.deserialize_compact(data, key2) + + with pytest.raises(ValueError): + jwe.serialize_compact(protected, b"hello", key2) + + +def test_xc20p_content_encryption_decryption(): + pytest.importorskip("Cryptodome.Cipher.ChaCha20_Poly1305") + + # https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-xchacha-03#appendix-A.3.1 + enc = JsonWebEncryption.ENC_REGISTRY["XC20P"] + + plaintext = bytes.fromhex( + "4c616469657320616e642047656e746c656d656e206f662074686520636c6173" + + "73206f66202739393a204966204920636f756c64206f6666657220796f75206f" + + "6e6c79206f6e652074697020666f7220746865206675747572652c2073756e73" + + "637265656e20776f756c642062652069742e" + ) + aad = bytes.fromhex("50515253c0c1c2c3c4c5c6c7") + key = bytes.fromhex( + "808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f" + ) + iv = bytes.fromhex("404142434445464748494a4b4c4d4e4f5051525354555657") + + ciphertext, tag = enc.encrypt(plaintext, aad, iv, key) + assert ciphertext == bytes.fromhex( + "bd6d179d3e83d43b9576579493c0e939572a1700252bfaccbed2902c21396cbb" + + "731c7f1b0b4aa6440bf3a82f4eda7e39ae64c6708c54c216cb96b72e1213b452" + + "2f8c9ba40db5d945b11b69b982c1bb9e3f3fac2bc369488f76b2383565d3fff9" + + "21f9664c97637da9768812f615c68b13b52e" + ) + assert tag == bytes.fromhex("c0875924c1c7987947deafd8780acf49") + + decrypted_plaintext = enc.decrypt(ciphertext, aad, iv, tag, key) + assert decrypted_plaintext == plaintext diff --git a/tests/jose/test_ecdh_1pu.py b/tests/jose/test_ecdh_1pu.py new file mode 100644 index 000000000..9da5e92f1 --- /dev/null +++ b/tests/jose/test_ecdh_1pu.py @@ -0,0 +1,1626 @@ +from collections import OrderedDict + +import pytest +from cryptography.hazmat.primitives.keywrap import InvalidUnwrap + +from authlib.common.encoding import json_b64encode +from authlib.common.encoding import json_loads +from authlib.common.encoding import to_bytes +from authlib.common.encoding import urlsafe_b64decode +from authlib.common.encoding import urlsafe_b64encode +from authlib.jose import ECKey +from authlib.jose import JsonWebEncryption +from authlib.jose import OKPKey +from authlib.jose.drafts import register_jwe_draft +from authlib.jose.errors import InvalidAlgorithmForMultipleRecipientsMode +from authlib.jose.errors import InvalidEncryptionAlgorithmForECDH1PUWithKeyWrappingError +from authlib.jose.rfc7516.models import JWEHeader + +register_jwe_draft(JsonWebEncryption) + + +def test_ecdh_1pu_key_agreement_computation_appx_a(): + # https://datatracker.ietf.org/doc/html/draft-madden-jose-ecdh-1pu-04#appendix-A + alice_static_key = { + "kty": "EC", + "crv": "P-256", + "x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis", + "y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE", + "d": "Hndv7ZZjs_ke8o9zXYo3iq-Yr8SewI5vrqd0pAvEPqg", + } + bob_static_key = { + "kty": "EC", + "crv": "P-256", + "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ", + "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck", + "d": "VEmDZpDXXK8p8N0Cndsxs924q6nS1RXFASRl6BfUqdw", + } + alice_ephemeral_key = { + "kty": "EC", + "crv": "P-256", + "x": "gI0GAILBdu7T53akrFmMyGcsF3n5dO7MmwNBHKW5SV0", + "y": "SLW_xSffzlPWrHEVI30DHM_4egVwt3NQqeUD7nMFpps", + "d": "0_NxaRPUMQoAJt50Gz8YiTr8gRTwyEaCumd-MToTmIo", + } + + headers = { + "alg": "ECDH-1PU", + "enc": "A256GCM", + "apu": "QWxpY2U", + "apv": "Qm9i", + "epk": { + "kty": "EC", + "crv": "P-256", + "x": "gI0GAILBdu7T53akrFmMyGcsF3n5dO7MmwNBHKW5SV0", + "y": "SLW_xSffzlPWrHEVI30DHM_4egVwt3NQqeUD7nMFpps", + }, + } + + alg = JsonWebEncryption.ALG_REGISTRY["ECDH-1PU"] + enc = JsonWebEncryption.ENC_REGISTRY["A256GCM"] + + alice_static_key = alg.prepare_key(alice_static_key) + bob_static_key = alg.prepare_key(bob_static_key) + alice_ephemeral_key = alg.prepare_key(alice_ephemeral_key) + + alice_static_pubkey = alice_static_key.get_op_key("wrapKey") + bob_static_pubkey = bob_static_key.get_op_key("wrapKey") + alice_ephemeral_pubkey = alice_ephemeral_key.get_op_key("wrapKey") + + # Derived key computation at Alice + + # Step-by-step methods verification + _shared_key_e_at_alice = alice_ephemeral_key.exchange_shared_key(bob_static_pubkey) + assert ( + _shared_key_e_at_alice + == b"\x9e\x56\xd9\x1d\x81\x71\x35\xd3\x72\x83\x42\x83\xbf\x84\x26\x9c" + + b"\xfb\x31\x6e\xa3\xda\x80\x6a\x48\xf6\xda\xa7\x79\x8c\xfe\x90\xc4" + ) + + _shared_key_s_at_alice = alice_static_key.exchange_shared_key(bob_static_pubkey) + assert ( + _shared_key_s_at_alice + == b"\xe3\xca\x34\x74\x38\x4c\x9f\x62\xb3\x0b\xfd\x4c\x68\x8b\x3e\x7d" + + b"\x41\x10\xa1\xb4\xba\xdc\x3c\xc5\x4e\xf7\xb8\x12\x41\xef\xd5\x0d" + ) + + _shared_key_at_alice = alg.compute_shared_key( + _shared_key_e_at_alice, _shared_key_s_at_alice + ) + assert ( + _shared_key_at_alice + == b"\x9e\x56\xd9\x1d\x81\x71\x35\xd3\x72\x83\x42\x83\xbf\x84\x26\x9c" + + b"\xfb\x31\x6e\xa3\xda\x80\x6a\x48\xf6\xda\xa7\x79\x8c\xfe\x90\xc4" + + b"\xe3\xca\x34\x74\x38\x4c\x9f\x62\xb3\x0b\xfd\x4c\x68\x8b\x3e\x7d" + + b"\x41\x10\xa1\xb4\xba\xdc\x3c\xc5\x4e\xf7\xb8\x12\x41\xef\xd5\x0d" + ) + + _fixed_info_at_alice = alg.compute_fixed_info(headers, enc.key_size, None) + assert ( + _fixed_info_at_alice + == b"\x00\x00\x00\x07\x41\x32\x35\x36\x47\x43\x4d\x00\x00\x00\x05\x41" + + b"\x6c\x69\x63\x65\x00\x00\x00\x03\x42\x6f\x62\x00\x00\x01\x00" + ) + + _dk_at_alice = alg.compute_derived_key( + _shared_key_at_alice, _fixed_info_at_alice, enc.key_size + ) + assert ( + _dk_at_alice + == b"\x6c\xaf\x13\x72\x3d\x14\x85\x0a\xd4\xb4\x2c\xd6\xdd\xe9\x35\xbf" + + b"\xfd\x2f\xff\x00\xa9\xba\x70\xde\x05\xc2\x03\xa5\xe1\x72\x2c\xa7" + ) + assert ( + urlsafe_b64encode(_dk_at_alice) + == b"bK8Tcj0UhQrUtCzW3ek1v_0v_wCpunDeBcIDpeFyLKc" + ) + + # All-in-one method verification + dk_at_alice = alg.deliver_at_sender( + alice_static_key, + alice_ephemeral_key, + bob_static_pubkey, + headers, + enc.key_size, + None, + ) + assert ( + urlsafe_b64encode(dk_at_alice) == b"bK8Tcj0UhQrUtCzW3ek1v_0v_wCpunDeBcIDpeFyLKc" + ) + + # Derived key computation at Bob + + # Step-by-step methods verification + _shared_key_e_at_bob = bob_static_key.exchange_shared_key(alice_ephemeral_pubkey) + assert _shared_key_e_at_bob == _shared_key_e_at_alice + + _shared_key_s_at_bob = bob_static_key.exchange_shared_key(alice_static_pubkey) + assert _shared_key_s_at_bob == _shared_key_s_at_alice + + _shared_key_at_bob = alg.compute_shared_key( + _shared_key_e_at_bob, _shared_key_s_at_bob + ) + assert _shared_key_at_bob == _shared_key_at_alice + + _fixed_info_at_bob = alg.compute_fixed_info(headers, enc.key_size, None) + assert _fixed_info_at_bob == _fixed_info_at_alice + + _dk_at_bob = alg.compute_derived_key( + _shared_key_at_bob, _fixed_info_at_bob, enc.key_size + ) + assert _dk_at_bob == _dk_at_alice + + # All-in-one method verification + dk_at_bob = alg.deliver_at_recipient( + bob_static_key, + alice_static_pubkey, + alice_ephemeral_pubkey, + headers, + enc.key_size, + None, + ) + assert dk_at_bob == dk_at_alice + + +def test_ecdh_1pu_key_agreement_computation_appx_b(): + # https://datatracker.ietf.org/doc/html/draft-madden-jose-ecdh-1pu-04#appendix-B + alice_static_key = { + "kty": "OKP", + "crv": "X25519", + "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4", + "d": "i9KuFhSzEBsiv3PKVL5115OCdsqQai5nj_Flzfkw5jU", + } + bob_static_key = { + "kty": "OKP", + "crv": "X25519", + "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw", + "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg", + } + charlie_static_key = { + "kty": "OKP", + "crv": "X25519", + "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE", + "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE", + } + alice_ephemeral_key = { + "kty": "OKP", + "crv": "X25519", + "x": "k9of_cpAajy0poW5gaixXGs9nHkwg1AFqUAFa39dyBc", + "d": "x8EVZH4Fwk673_mUujnliJoSrLz0zYzzCWp5GUX2fc8", + } + + protected = OrderedDict( + { + "alg": "ECDH-1PU+A128KW", + "enc": "A256CBC-HS512", + "apu": "QWxpY2U", + "apv": "Qm9iIGFuZCBDaGFybGll", + "epk": OrderedDict( + { + "kty": "OKP", + "crv": "X25519", + "x": "k9of_cpAajy0poW5gaixXGs9nHkwg1AFqUAFa39dyBc", + } + ), + } + ) + + cek = ( + b"\xff\xfe\xfd\xfc\xfb\xfa\xf9\xf8\xf7\xf6\xf5\xf4\xf3\xf2\xf1\xf0" + b"\xef\xee\xed\xec\xeb\xea\xe9\xe8\xe7\xe6\xe5\xe4\xe3\xe2\xe1\xe0" + b"\xdf\xde\xdd\xdc\xdb\xda\xd9\xd8\xd7\xd6\xd5\xd4\xd3\xd2\xd1\xd0" + b"\xcf\xce\xcd\xcc\xcb\xca\xc9\xc8\xc7\xc6\xc5\xc4\xc3\xc2\xc1\xc0" + ) + + iv = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f" + + payload = b"Three is a magic number." + + alg = JsonWebEncryption.ALG_REGISTRY["ECDH-1PU+A128KW"] + enc = JsonWebEncryption.ENC_REGISTRY["A256CBC-HS512"] + + alice_static_key = OKPKey.import_key(alice_static_key) + bob_static_key = OKPKey.import_key(bob_static_key) + charlie_static_key = OKPKey.import_key(charlie_static_key) + alice_ephemeral_key = OKPKey.import_key(alice_ephemeral_key) + + alice_static_pubkey = alice_static_key.get_op_key("wrapKey") + bob_static_pubkey = bob_static_key.get_op_key("wrapKey") + charlie_static_pubkey = charlie_static_key.get_op_key("wrapKey") + alice_ephemeral_pubkey = alice_ephemeral_key.get_op_key("wrapKey") + + protected_segment = json_b64encode(protected) + aad = to_bytes(protected_segment, "ascii") + + ciphertext, tag = enc.encrypt(payload, aad, iv, cek) + assert ( + urlsafe_b64encode(ciphertext) == b"Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFYw" + ) + assert urlsafe_b64encode(tag) == b"HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ" + + # Derived key computation at Alice for Bob + + # Step-by-step methods verification + _shared_key_e_at_alice_for_bob = alice_ephemeral_key.exchange_shared_key( + bob_static_pubkey + ) + assert ( + _shared_key_e_at_alice_for_bob + == b"\x32\x81\x08\x96\xe0\xfe\x4d\x57\x0e\xd1\xac\xfc\xed\xf6\x71\x17" + + b"\xdc\x19\x4e\xd5\xda\xac\x21\xd8\xff\x7a\xf3\x24\x46\x94\x89\x7f" + ) + + _shared_key_s_at_alice_for_bob = alice_static_key.exchange_shared_key( + bob_static_pubkey + ) + assert ( + _shared_key_s_at_alice_for_bob + == b"\x21\x57\x61\x2c\x90\x48\xed\xfa\xe7\x7c\xb2\xe4\x23\x71\x40\x60" + + b"\x59\x67\xc0\x5c\x7f\x77\xa4\x8e\xea\xf2\xcf\x29\xa5\x73\x7c\x4a" + ) + + _shared_key_at_alice_for_bob = alg.compute_shared_key( + _shared_key_e_at_alice_for_bob, _shared_key_s_at_alice_for_bob + ) + assert ( + _shared_key_at_alice_for_bob + == b"\x32\x81\x08\x96\xe0\xfe\x4d\x57\x0e\xd1\xac\xfc\xed\xf6\x71\x17" + + b"\xdc\x19\x4e\xd5\xda\xac\x21\xd8\xff\x7a\xf3\x24\x46\x94\x89\x7f" + + b"\x21\x57\x61\x2c\x90\x48\xed\xfa\xe7\x7c\xb2\xe4\x23\x71\x40\x60" + + b"\x59\x67\xc0\x5c\x7f\x77\xa4\x8e\xea\xf2\xcf\x29\xa5\x73\x7c\x4a" + ) + + _fixed_info_at_alice_for_bob = alg.compute_fixed_info(protected, alg.key_size, tag) + assert ( + _fixed_info_at_alice_for_bob + == b"\x00\x00\x00\x0f\x45\x43\x44\x48\x2d\x31\x50\x55\x2b\x41\x31\x32" + + b"\x38\x4b\x57\x00\x00\x00\x05\x41\x6c\x69\x63\x65\x00\x00\x00\x0f" + + b"\x42\x6f\x62\x20\x61\x6e\x64\x20\x43\x68\x61\x72\x6c\x69\x65\x00" + + b"\x00\x00\x80\x00\x00\x00\x20\x1c\xb6\xf8\x7d\x39\x66\xf2\xca\x46" + + b"\x9a\x28\xf7\x47\x23\xac\xda\x02\x78\x0e\x91\xcc\xe2\x18\x55\x47" + + b"\x07\x45\xfe\x11\x9b\xdd\x64" + ) + + _dk_at_alice_for_bob = alg.compute_derived_key( + _shared_key_at_alice_for_bob, _fixed_info_at_alice_for_bob, alg.key_size + ) + assert ( + _dk_at_alice_for_bob + == b"\xdf\x4c\x37\xa0\x66\x83\x06\xa1\x1e\x3d\x6b\x00\x74\xb5\xd8\xdf" + ) + + # All-in-one method verification + dk_at_alice_for_bob = alg.deliver_at_sender( + alice_static_key, + alice_ephemeral_key, + bob_static_pubkey, + protected, + alg.key_size, + tag, + ) + assert ( + dk_at_alice_for_bob + == b"\xdf\x4c\x37\xa0\x66\x83\x06\xa1\x1e\x3d\x6b\x00\x74\xb5\xd8\xdf" + ) + + kek_at_alice_for_bob = alg.aeskw.prepare_key(dk_at_alice_for_bob) + wrapped_for_bob = alg.aeskw.wrap_cek(cek, kek_at_alice_for_bob) + ek_for_bob = wrapped_for_bob["ek"] + assert ( + urlsafe_b64encode(ek_for_bob) + == b"pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQeU1cSl55cQ0hGezJu2N9IY0QN" + ) + + # Derived key computation at Alice for Charlie + + # Step-by-step methods verification + _shared_key_e_at_alice_for_charlie = alice_ephemeral_key.exchange_shared_key( + charlie_static_pubkey + ) + assert ( + _shared_key_e_at_alice_for_charlie + == b"\x89\xdc\xfe\x4c\x37\xc1\xdc\x02\x71\xf3\x46\xb5\xb3\xb1\x9c\x3b" + + b"\x70\x5c\xa2\xa7\x2f\x9a\x23\x77\x85\xc3\x44\x06\xfc\xb7\x5f\x10" + ) + + _shared_key_s_at_alice_for_charlie = alice_static_key.exchange_shared_key( + charlie_static_pubkey + ) + assert ( + _shared_key_s_at_alice_for_charlie + == b"\x78\xfe\x63\xfc\x66\x1c\xf8\xd1\x8f\x92\xa8\x42\x2a\x64\x18\xe4" + + b"\xed\x5e\x20\xa9\x16\x81\x85\xfd\xee\xdc\xa1\xc3\xd8\xe6\xa6\x1c" + ) + + _shared_key_at_alice_for_charlie = alg.compute_shared_key( + _shared_key_e_at_alice_for_charlie, _shared_key_s_at_alice_for_charlie + ) + assert ( + _shared_key_at_alice_for_charlie + == b"\x89\xdc\xfe\x4c\x37\xc1\xdc\x02\x71\xf3\x46\xb5\xb3\xb1\x9c\x3b" + + b"\x70\x5c\xa2\xa7\x2f\x9a\x23\x77\x85\xc3\x44\x06\xfc\xb7\x5f\x10" + + b"\x78\xfe\x63\xfc\x66\x1c\xf8\xd1\x8f\x92\xa8\x42\x2a\x64\x18\xe4" + + b"\xed\x5e\x20\xa9\x16\x81\x85\xfd\xee\xdc\xa1\xc3\xd8\xe6\xa6\x1c" + ) + + _fixed_info_at_alice_for_charlie = alg.compute_fixed_info( + protected, alg.key_size, tag + ) + assert _fixed_info_at_alice_for_charlie == _fixed_info_at_alice_for_bob + + _dk_at_alice_for_charlie = alg.compute_derived_key( + _shared_key_at_alice_for_charlie, + _fixed_info_at_alice_for_charlie, + alg.key_size, + ) + assert ( + _dk_at_alice_for_charlie + == b"\x57\xd8\x12\x6f\x1b\x7e\xc4\xcc\xb0\x58\x4d\xac\x03\xcb\x27\xcc" + ) + + # All-in-one method verification + dk_at_alice_for_charlie = alg.deliver_at_sender( + alice_static_key, + alice_ephemeral_key, + charlie_static_pubkey, + protected, + alg.key_size, + tag, + ) + assert ( + dk_at_alice_for_charlie + == b"\x57\xd8\x12\x6f\x1b\x7e\xc4\xcc\xb0\x58\x4d\xac\x03\xcb\x27\xcc" + ) + + kek_at_alice_for_charlie = alg.aeskw.prepare_key(dk_at_alice_for_charlie) + wrapped_for_charlie = alg.aeskw.wrap_cek(cek, kek_at_alice_for_charlie) + ek_for_charlie = wrapped_for_charlie["ek"] + assert ( + urlsafe_b64encode(ek_for_charlie) + == b"56GVudgRLIMEElQ7DpXsijJVRSWUSDNdbWkdV3g0GUNq6hcT_GkxwnxlPIWrTXCqRpVKQC8fe4z3PQ2YH2afvjQ28aiCTWFE" + ) + + # Derived key computation at Bob for Alice + + # Step-by-step methods verification + _shared_key_e_at_bob_for_alice = bob_static_key.exchange_shared_key( + alice_ephemeral_pubkey + ) + assert _shared_key_e_at_bob_for_alice == _shared_key_e_at_alice_for_bob + + _shared_key_s_at_bob_for_alice = bob_static_key.exchange_shared_key( + alice_static_pubkey + ) + assert _shared_key_s_at_bob_for_alice == _shared_key_s_at_alice_for_bob + + _shared_key_at_bob_for_alice = alg.compute_shared_key( + _shared_key_e_at_bob_for_alice, _shared_key_s_at_bob_for_alice + ) + assert _shared_key_at_bob_for_alice == _shared_key_at_alice_for_bob + + _fixed_info_at_bob_for_alice = alg.compute_fixed_info(protected, alg.key_size, tag) + assert _fixed_info_at_bob_for_alice == _fixed_info_at_alice_for_bob + + _dk_at_bob_for_alice = alg.compute_derived_key( + _shared_key_at_bob_for_alice, _fixed_info_at_bob_for_alice, alg.key_size + ) + assert _dk_at_bob_for_alice == _dk_at_alice_for_bob + + # All-in-one method verification + dk_at_bob_for_alice = alg.deliver_at_recipient( + bob_static_key, + alice_static_pubkey, + alice_ephemeral_pubkey, + protected, + alg.key_size, + tag, + ) + assert dk_at_bob_for_alice == dk_at_alice_for_bob + + kek_at_bob_for_alice = alg.aeskw.prepare_key(dk_at_bob_for_alice) + cek_unwrapped_by_bob = alg.aeskw.unwrap( + enc, ek_for_bob, protected, kek_at_bob_for_alice + ) + assert cek_unwrapped_by_bob == cek + + payload_decrypted_by_bob = enc.decrypt( + ciphertext, aad, iv, tag, cek_unwrapped_by_bob + ) + assert payload_decrypted_by_bob == payload + + # Derived key computation at Charlie for Alice + + # Step-by-step methods verification + _shared_key_e_at_charlie_for_alice = charlie_static_key.exchange_shared_key( + alice_ephemeral_pubkey + ) + assert _shared_key_e_at_charlie_for_alice == _shared_key_e_at_alice_for_charlie + + _shared_key_s_at_charlie_for_alice = charlie_static_key.exchange_shared_key( + alice_static_pubkey + ) + assert _shared_key_s_at_charlie_for_alice == _shared_key_s_at_alice_for_charlie + + _shared_key_at_charlie_for_alice = alg.compute_shared_key( + _shared_key_e_at_charlie_for_alice, _shared_key_s_at_charlie_for_alice + ) + assert _shared_key_at_charlie_for_alice == _shared_key_at_alice_for_charlie + + _fixed_info_at_charlie_for_alice = alg.compute_fixed_info( + protected, alg.key_size, tag + ) + assert _fixed_info_at_charlie_for_alice == _fixed_info_at_alice_for_charlie + + _dk_at_charlie_for_alice = alg.compute_derived_key( + _shared_key_at_charlie_for_alice, + _fixed_info_at_charlie_for_alice, + alg.key_size, + ) + assert _dk_at_charlie_for_alice == _dk_at_alice_for_charlie + + # All-in-one method verification + dk_at_charlie_for_alice = alg.deliver_at_recipient( + charlie_static_key, + alice_static_pubkey, + alice_ephemeral_pubkey, + protected, + alg.key_size, + tag, + ) + assert dk_at_charlie_for_alice == dk_at_alice_for_charlie + + kek_at_charlie_for_alice = alg.aeskw.prepare_key(dk_at_charlie_for_alice) + cek_unwrapped_by_charlie = alg.aeskw.unwrap( + enc, ek_for_charlie, protected, kek_at_charlie_for_alice + ) + assert cek_unwrapped_by_charlie == cek + + payload_decrypted_by_charlie = enc.decrypt( + ciphertext, aad, iv, tag, cek_unwrapped_by_charlie + ) + assert payload_decrypted_by_charlie == payload + + +def test_ecdh_1pu_jwe_in_direct_key_agreement_mode(): + jwe = JsonWebEncryption() + alice_key = { + "kty": "EC", + "crv": "P-256", + "x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis", + "y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE", + "d": "Hndv7ZZjs_ke8o9zXYo3iq-Yr8SewI5vrqd0pAvEPqg", + } + bob_key = { + "kty": "EC", + "crv": "P-256", + "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ", + "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck", + "d": "VEmDZpDXXK8p8N0Cndsxs924q6nS1RXFASRl6BfUqdw", + } + + for enc in [ + "A128CBC-HS256", + "A192CBC-HS384", + "A256CBC-HS512", + "A128GCM", + "A192GCM", + "A256GCM", + ]: + protected = {"alg": "ECDH-1PU", "enc": enc} + data = jwe.serialize_compact(protected, b"hello", bob_key, sender_key=alice_key) + rv = jwe.deserialize_compact(data, bob_key, sender_key=alice_key) + assert rv["payload"] == b"hello" + + +def test_ecdh_1pu_jwe_json_serialization_single_recipient_in_direct_key_agreement_mode(): + jwe = JsonWebEncryption() + alice_key = OKPKey.generate_key("X25519", is_private=True) + bob_key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-1PU", "enc": "A128GCM"} + header_obj = {"protected": protected} + data = jwe.serialize_json(header_obj, b"hello", bob_key, sender_key=alice_key) + rv = jwe.deserialize_json(data, bob_key, sender_key=alice_key) + assert rv["payload"] == b"hello" + + +def test_ecdh_1pu_jwe_in_key_agreement_with_key_wrapping_mode(): + jwe = JsonWebEncryption() + alice_key = { + "kty": "EC", + "crv": "P-256", + "x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis", + "y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE", + "d": "Hndv7ZZjs_ke8o9zXYo3iq-Yr8SewI5vrqd0pAvEPqg", + } + bob_key = { + "kty": "EC", + "crv": "P-256", + "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ", + "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck", + "d": "VEmDZpDXXK8p8N0Cndsxs924q6nS1RXFASRl6BfUqdw", + } + + for alg in [ + "ECDH-1PU+A128KW", + "ECDH-1PU+A192KW", + "ECDH-1PU+A256KW", + ]: + for enc in [ + "A128CBC-HS256", + "A192CBC-HS384", + "A256CBC-HS512", + ]: + protected = {"alg": alg, "enc": enc} + data = jwe.serialize_compact( + protected, b"hello", bob_key, sender_key=alice_key + ) + rv = jwe.deserialize_compact(data, bob_key, sender_key=alice_key) + assert rv["payload"] == b"hello" + + +def test_ecdh_1pu_jwe_with_compact_serialization_ignores_kid_provided_separately_on_decryption(): + jwe = JsonWebEncryption() + + alice_key = { + "kty": "EC", + "crv": "P-256", + "x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis", + "y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE", + "d": "Hndv7ZZjs_ke8o9zXYo3iq-Yr8SewI5vrqd0pAvEPqg", + } + + bob_kid = "Bob's key" + bob_key = { + "kty": "EC", + "crv": "P-256", + "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ", + "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck", + "d": "VEmDZpDXXK8p8N0Cndsxs924q6nS1RXFASRl6BfUqdw", + } + + for alg in [ + "ECDH-1PU+A128KW", + "ECDH-1PU+A192KW", + "ECDH-1PU+A256KW", + ]: + for enc in [ + "A128CBC-HS256", + "A192CBC-HS384", + "A256CBC-HS512", + ]: + protected = {"alg": alg, "enc": enc} + data = jwe.serialize_compact( + protected, b"hello", bob_key, sender_key=alice_key + ) + rv = jwe.deserialize_compact(data, (bob_kid, bob_key), sender_key=alice_key) + assert rv["payload"] == b"hello" + + +def test_ecdh_1pu_jwe_with_okp_keys_in_direct_key_agreement_mode(): + jwe = JsonWebEncryption() + alice_key = OKPKey.generate_key("X25519", is_private=True) + bob_key = OKPKey.generate_key("X25519", is_private=True) + + for enc in [ + "A128CBC-HS256", + "A192CBC-HS384", + "A256CBC-HS512", + "A128GCM", + "A192GCM", + "A256GCM", + ]: + protected = {"alg": "ECDH-1PU", "enc": enc} + data = jwe.serialize_compact(protected, b"hello", bob_key, sender_key=alice_key) + rv = jwe.deserialize_compact(data, bob_key, sender_key=alice_key) + assert rv["payload"] == b"hello" + + +def test_ecdh_1pu_jwe_with_okp_keys_in_key_agreement_with_key_wrapping_mode(): + jwe = JsonWebEncryption() + alice_key = OKPKey.generate_key("X25519", is_private=True) + bob_key = OKPKey.generate_key("X25519", is_private=True) + + for alg in [ + "ECDH-1PU+A128KW", + "ECDH-1PU+A192KW", + "ECDH-1PU+A256KW", + ]: + for enc in [ + "A128CBC-HS256", + "A192CBC-HS384", + "A256CBC-HS512", + ]: + protected = {"alg": alg, "enc": enc} + data = jwe.serialize_compact( + protected, b"hello", bob_key, sender_key=alice_key + ) + rv = jwe.deserialize_compact(data, bob_key, sender_key=alice_key) + assert rv["payload"] == b"hello" + + +def test_ecdh_1pu_encryption_with_json_serialization(): + jwe = JsonWebEncryption() + + alice_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4", + "d": "i9KuFhSzEBsiv3PKVL5115OCdsqQai5nj_Flzfkw5jU", + } + ) + bob_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw", + "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg", + } + ) + charlie_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE", + "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE", + } + ) + + protected = { + "alg": "ECDH-1PU+A128KW", + "enc": "A256CBC-HS512", + "apu": "QWxpY2U", + "apv": "Qm9iIGFuZCBDaGFybGll", + } + + unprotected = {"jku": "https://provider.test/jwks"} + + recipients = [ + {"header": {"kid": "bob-key-2"}}, + {"header": {"kid": "2021-05-06"}}, + ] + + jwe_aad = b"Authenticate me too." + + header_obj = { + "protected": protected, + "unprotected": unprotected, + "recipients": recipients, + "aad": jwe_aad, + } + + payload = b"Three is a magic number." + + data = jwe.serialize_json( + header_obj, payload, [bob_key, charlie_key], sender_key=alice_key + ) + + assert data.keys() == { + "protected", + "unprotected", + "recipients", + "aad", + "iv", + "ciphertext", + "tag", + } + + decoded_protected = json_loads( + urlsafe_b64decode(to_bytes(data["protected"])).decode("utf-8") + ) + assert decoded_protected.keys() == protected.keys() | {"epk"} + assert { + k: decoded_protected[k] for k in decoded_protected.keys() - {"epk"} + } == protected + + assert data["unprotected"] == unprotected + + assert len(data["recipients"]) == len(recipients) + for i in range(len(data["recipients"])): + assert data["recipients"][i].keys() == {"header", "encrypted_key"} + assert data["recipients"][i]["header"] == recipients[i]["header"] + + assert urlsafe_b64decode(to_bytes(data["aad"])) == jwe_aad + + iv = urlsafe_b64decode(to_bytes(data["iv"])) + ciphertext = urlsafe_b64decode(to_bytes(data["ciphertext"])) + tag = urlsafe_b64decode(to_bytes(data["tag"])) + + alg = JsonWebEncryption.ALG_REGISTRY[protected["alg"]] + enc = JsonWebEncryption.ENC_REGISTRY[protected["enc"]] + + aad = to_bytes(data["protected"]) + b"." + to_bytes(data["aad"]) + aad = to_bytes(aad, "ascii") + + ek_for_bob = urlsafe_b64decode(to_bytes(data["recipients"][0]["encrypted_key"])) + header_for_bob = JWEHeader( + decoded_protected, data["unprotected"], data["recipients"][0]["header"] + ) + cek_at_bob = alg.unwrap( + enc, ek_for_bob, header_for_bob, bob_key, sender_key=alice_key, tag=tag + ) + payload_at_bob = enc.decrypt(ciphertext, aad, iv, tag, cek_at_bob) + + assert payload_at_bob == payload + + ek_for_charlie = urlsafe_b64decode(to_bytes(data["recipients"][1]["encrypted_key"])) + header_for_charlie = JWEHeader( + decoded_protected, data["unprotected"], data["recipients"][1]["header"] + ) + cek_at_charlie = alg.unwrap( + enc, + ek_for_charlie, + header_for_charlie, + charlie_key, + sender_key=alice_key, + tag=tag, + ) + payload_at_charlie = enc.decrypt(ciphertext, aad, iv, tag, cek_at_charlie) + + assert cek_at_charlie == cek_at_bob + assert payload_at_charlie == payload + + +def test_ecdh_1pu_decryption_with_json_serialization(): + jwe = JsonWebEncryption() + + alice_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4", + "d": "i9KuFhSzEBsiv3PKVL5115OCdsqQai5nj_Flzfkw5jU", + } + ) + bob_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw", + "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg", + } + ) + charlie_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE", + "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE", + } + ) + + data = { + "protected": "eyJhbGciOiJFQ0RILTFQVStBMTI4S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwiYXB1Ijoi" + + "UVd4cFkyVSIsImFwdiI6IlFtOWlJR0Z1WkNCRGFHRnliR2xsIiwiZXBrIjp7Imt0eSI6Ik9L" + + "UCIsImNydiI6IlgyNTUxOSIsIngiOiJrOW9mX2NwQWFqeTBwb1c1Z2FpeFhHczluSGt3ZzFB" + + "RnFVQUZhMzlkeUJjIn19", + "unprotected": {"jku": "https://provider.test/jwks"}, + "recipients": [ + { + "header": {"kid": "bob-key-2"}, + "encrypted_key": "pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQ" + + "eU1cSl55cQ0hGezJu2N9IY0QN", + }, + { + "header": {"kid": "2021-05-06"}, + "encrypted_key": "56GVudgRLIMEElQ7DpXsijJVRSWUSDNdbWkdV3g0GUNq6hcT_GkxwnxlPIWrTXCqRpVKQC8" + + "fe4z3PQ2YH2afvjQ28aiCTWFE", + }, + ], + "iv": "AAECAwQFBgcICQoLDA0ODw", + "ciphertext": "Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFYw", + "tag": "HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ", + } + + rv_at_bob = jwe.deserialize_json(data, bob_key, sender_key=alice_key) + + assert rv_at_bob.keys() == {"header", "payload"} + + assert rv_at_bob["header"].keys() == {"protected", "unprotected", "recipients"} + + assert rv_at_bob["header"]["protected"] == { + "alg": "ECDH-1PU+A128KW", + "enc": "A256CBC-HS512", + "apu": "QWxpY2U", + "apv": "Qm9iIGFuZCBDaGFybGll", + "epk": { + "kty": "OKP", + "crv": "X25519", + "x": "k9of_cpAajy0poW5gaixXGs9nHkwg1AFqUAFa39dyBc", + }, + } + + assert rv_at_bob["header"]["unprotected"] == {"jku": "https://provider.test/jwks"} + + assert rv_at_bob["header"]["recipients"] == [ + {"header": {"kid": "bob-key-2"}}, + {"header": {"kid": "2021-05-06"}}, + ] + + assert rv_at_bob["payload"] == b"Three is a magic number." + + rv_at_charlie = jwe.deserialize_json(data, charlie_key, sender_key=alice_key) + + assert rv_at_charlie.keys() == {"header", "payload"} + + assert rv_at_charlie["header"].keys() == { + "protected", + "unprotected", + "recipients", + } + + assert rv_at_charlie["header"]["protected"] == { + "alg": "ECDH-1PU+A128KW", + "enc": "A256CBC-HS512", + "apu": "QWxpY2U", + "apv": "Qm9iIGFuZCBDaGFybGll", + "epk": { + "kty": "OKP", + "crv": "X25519", + "x": "k9of_cpAajy0poW5gaixXGs9nHkwg1AFqUAFa39dyBc", + }, + } + + assert rv_at_charlie["header"]["unprotected"] == { + "jku": "https://provider.test/jwks" + } + + assert rv_at_charlie["header"]["recipients"] == [ + {"header": {"kid": "bob-key-2"}}, + {"header": {"kid": "2021-05-06"}}, + ] + + assert rv_at_charlie["payload"] == b"Three is a magic number." + + +def test_ecdh_1pu_jwe_with_json_serialization_when_kid_is_not_specified(): + jwe = JsonWebEncryption() + + alice_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4", + "d": "i9KuFhSzEBsiv3PKVL5115OCdsqQai5nj_Flzfkw5jU", + } + ) + bob_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw", + "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg", + } + ) + charlie_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE", + "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE", + } + ) + + protected = { + "alg": "ECDH-1PU+A128KW", + "enc": "A256CBC-HS512", + "apu": "QWxpY2U", + "apv": "Qm9iIGFuZCBDaGFybGll", + } + + unprotected = {"jku": "https://provider.test/jwks"} + + recipients = [ + {"header": {"kid": "bob-key-2"}}, + {"header": {"kid": "2021-05-06"}}, + ] + + jwe_aad = b"Authenticate me too." + + header_obj = { + "protected": protected, + "unprotected": unprotected, + "recipients": recipients, + "aad": jwe_aad, + } + + payload = b"Three is a magic number." + + data = jwe.serialize_json( + header_obj, payload, [bob_key, charlie_key], sender_key=alice_key + ) + + rv_at_bob = jwe.deserialize_json(data, bob_key, sender_key=alice_key) + + assert rv_at_bob["header"]["protected"].keys() == protected.keys() | {"epk"} + assert { + k: rv_at_bob["header"]["protected"][k] + for k in rv_at_bob["header"]["protected"].keys() - {"epk"} + } == protected + assert rv_at_bob["header"]["unprotected"] == unprotected + assert rv_at_bob["header"]["recipients"] == recipients + assert rv_at_bob["header"]["aad"] == jwe_aad + assert rv_at_bob["payload"] == payload + + rv_at_charlie = jwe.deserialize_json(data, charlie_key, sender_key=alice_key) + + assert rv_at_charlie["header"]["protected"].keys() == protected.keys() | {"epk"} + assert { + k: rv_at_charlie["header"]["protected"][k] + for k in rv_at_charlie["header"]["protected"].keys() - {"epk"} + } == protected + assert rv_at_charlie["header"]["unprotected"] == unprotected + assert rv_at_charlie["header"]["recipients"] == recipients + assert rv_at_charlie["header"]["aad"] == jwe_aad + assert rv_at_charlie["payload"] == payload + + +def test_ecdh_1pu_jwe_with_json_serialization_when_kid_is_specified(): + jwe = JsonWebEncryption() + + alice_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "kid": "alice-key", + "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4", + "d": "i9KuFhSzEBsiv3PKVL5115OCdsqQai5nj_Flzfkw5jU", + } + ) + bob_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "kid": "bob-key-2", + "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw", + "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg", + } + ) + charlie_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "kid": "2021-05-06", + "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE", + "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE", + } + ) + + protected = { + "alg": "ECDH-1PU+A128KW", + "enc": "A256CBC-HS512", + "apu": "QWxpY2U", + "apv": "Qm9iIGFuZCBDaGFybGll", + } + + unprotected = {"jku": "https://provider.test/jwks"} + + recipients = [ + {"header": {"kid": "bob-key-2"}}, + {"header": {"kid": "2021-05-06"}}, + ] + + jwe_aad = b"Authenticate me too." + + header_obj = { + "protected": protected, + "unprotected": unprotected, + "recipients": recipients, + "aad": jwe_aad, + } + + payload = b"Three is a magic number." + + data = jwe.serialize_json( + header_obj, payload, [bob_key, charlie_key], sender_key=alice_key + ) + + rv_at_bob = jwe.deserialize_json(data, bob_key, sender_key=alice_key) + + assert rv_at_bob["header"]["protected"].keys() == protected.keys() | {"epk"} + assert { + k: rv_at_bob["header"]["protected"][k] + for k in rv_at_bob["header"]["protected"].keys() - {"epk"} + } == protected + assert rv_at_bob["header"]["unprotected"] == unprotected + assert rv_at_bob["header"]["recipients"] == recipients + assert rv_at_bob["header"]["aad"] == jwe_aad + assert rv_at_bob["payload"] == payload + + rv_at_charlie = jwe.deserialize_json(data, charlie_key, sender_key=alice_key) + + assert rv_at_charlie["header"]["protected"].keys() == protected.keys() | {"epk"} + assert { + k: rv_at_charlie["header"]["protected"][k] + for k in rv_at_charlie["header"]["protected"].keys() - {"epk"} + } == protected + assert rv_at_charlie["header"]["unprotected"] == unprotected + assert rv_at_charlie["header"]["recipients"] == recipients + assert rv_at_charlie["header"]["aad"] == jwe_aad + assert rv_at_charlie["payload"] == payload + + +def test_ecdh_1pu_jwe_with_json_serialization_when_kid_is_provided_separately_on_decryption(): + jwe = JsonWebEncryption() + + alice_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "kid": "WjKgJV7VRw3hmgU6--4v15c0Aewbcvat1BsRFTIqa5Q", + "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4", + "d": "i9KuFhSzEBsiv3PKVL5115OCdsqQai5nj_Flzfkw5jU", + } + ) + + bob_kid = "did:example:123#_Qq0UL2Fq651Q0Fjd6TvnYE-faHiOpRlPVQcY_-tA4A" + bob_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "kid": "_Qq0UL2Fq651Q0Fjd6TvnYE-faHiOpRlPVQcY_-tA4A", + "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw", + "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg", + } + ) + + charlie_kid = "did:example:123#_TKzHv2jFIyvdTGF1Dsgwngfdg3SH6TpDv0Ta1aOEkw" + charlie_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "kid": "_TKzHv2jFIyvdTGF1Dsgwngfdg3SH6TpDv0Ta1aOEkw", + "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE", + "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE", + } + ) + + protected = { + "alg": "ECDH-1PU+A128KW", + "enc": "A256CBC-HS512", + "apu": "QWxpY2U", + "apv": "Qm9iIGFuZCBDaGFybGll", + } + + unprotected = {"jku": "https://provider.test/jwks"} + + recipients = [ + { + "header": { + "kid": "did:example:123#_Qq0UL2Fq651Q0Fjd6TvnYE-faHiOpRlPVQcY_-tA4A" + } + }, + { + "header": { + "kid": "did:example:123#_TKzHv2jFIyvdTGF1Dsgwngfdg3SH6TpDv0Ta1aOEkw" + } + }, + ] + + jwe_aad = b"Authenticate me too." + + header_obj = { + "protected": protected, + "unprotected": unprotected, + "recipients": recipients, + "aad": jwe_aad, + } + + payload = b"Three is a magic number." + + data = jwe.serialize_json( + header_obj, payload, [bob_key, charlie_key], sender_key=alice_key + ) + + rv_at_bob = jwe.deserialize_json(data, (bob_kid, bob_key), sender_key=alice_key) + + assert rv_at_bob["header"]["protected"].keys() == protected.keys() | {"epk"} + assert { + k: rv_at_bob["header"]["protected"][k] + for k in rv_at_bob["header"]["protected"].keys() - {"epk"} + } == protected + assert rv_at_bob["header"]["unprotected"] == unprotected + assert rv_at_bob["header"]["recipients"] == recipients + assert rv_at_bob["header"]["aad"] == jwe_aad + assert rv_at_bob["payload"] == payload + + rv_at_charlie = jwe.deserialize_json( + data, (charlie_kid, charlie_key), sender_key=alice_key + ) + + assert rv_at_charlie["header"]["protected"].keys() == protected.keys() | {"epk"} + assert { + k: rv_at_charlie["header"]["protected"][k] + for k in rv_at_charlie["header"]["protected"].keys() - {"epk"} + } == protected + assert rv_at_charlie["header"]["unprotected"] == unprotected + assert rv_at_charlie["header"]["recipients"] == recipients + assert rv_at_charlie["header"]["aad"] == jwe_aad + assert rv_at_charlie["payload"] == payload + + +def test_ecdh_1pu_jwe_with_json_serialization_for_single_recipient(): + jwe = JsonWebEncryption() + + alice_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4", + "d": "i9KuFhSzEBsiv3PKVL5115OCdsqQai5nj_Flzfkw5jU", + } + ) + bob_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw", + "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg", + } + ) + + protected = { + "alg": "ECDH-1PU+A128KW", + "enc": "A256CBC-HS512", + "apu": "QWxpY2U", + "apv": "Qm9i", + } + + unprotected = {"jku": "https://provider.test/jwks"} + + recipients = [{"header": {"kid": "bob-key-2"}}] + + jwe_aad = b"Authenticate me too." + + header_obj = { + "protected": protected, + "unprotected": unprotected, + "recipients": recipients, + "aad": jwe_aad, + } + + payload = b"Three is a magic number." + + data = jwe.serialize_json(header_obj, payload, bob_key, sender_key=alice_key) + + rv = jwe.deserialize_json(data, bob_key, sender_key=alice_key) + + assert rv["header"]["protected"].keys() == protected.keys() | {"epk"} + assert { + k: rv["header"]["protected"][k] + for k in rv["header"]["protected"].keys() - {"epk"} + } == protected + assert rv["header"]["unprotected"] == unprotected + assert rv["header"]["recipients"] == recipients + assert rv["header"]["aad"] == jwe_aad + assert rv["payload"] == payload + + +def test_ecdh_1pu_encryption_fails_json_serialization_multiple_recipients_in_direct_key_agreement_mode(): + jwe = JsonWebEncryption() + alice_key = OKPKey.generate_key("X25519", is_private=True) + bob_key = OKPKey.generate_key("X25519", is_private=True) + charlie_key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-1PU", "enc": "A128GCM"} + header_obj = {"protected": protected} + with pytest.raises(InvalidAlgorithmForMultipleRecipientsMode): + jwe.serialize_json( + header_obj, + b"hello", + [bob_key, charlie_key], + sender_key=alice_key, + ) + + +def test_ecdh_1pu_encryption_fails_if_not_aes_cbc_hmac_sha2_enc_is_used_with_kw(): + jwe = JsonWebEncryption() + alice_key = { + "kty": "EC", + "crv": "P-256", + "x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis", + "y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE", + "d": "Hndv7ZZjs_ke8o9zXYo3iq-Yr8SewI5vrqd0pAvEPqg", + } + bob_key = { + "kty": "EC", + "crv": "P-256", + "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ", + "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck", + } + + for alg in [ + "ECDH-1PU+A128KW", + "ECDH-1PU+A192KW", + "ECDH-1PU+A256KW", + ]: + for enc in [ + "A128GCM", + "A192GCM", + "A256GCM", + ]: + protected = {"alg": alg, "enc": enc} + with pytest.raises( + InvalidEncryptionAlgorithmForECDH1PUWithKeyWrappingError + ): + jwe.serialize_compact( + protected, + b"hello", + bob_key, + sender_key=alice_key, + ) + + +def test_ecdh_1pu_encryption_with_public_sender_key_fails(): + jwe = JsonWebEncryption() + protected = {"alg": "ECDH-1PU", "enc": "A256GCM"} + + alice_key = { + "kty": "EC", + "crv": "P-256", + "x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis", + "y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE", + } + bob_key = { + "kty": "EC", + "crv": "P-256", + "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ", + "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck", + "d": "VEmDZpDXXK8p8N0Cndsxs924q6nS1RXFASRl6BfUqdw", + } + with pytest.raises(ValueError): + jwe.serialize_compact( + protected, + b"hello", + bob_key, + sender_key=alice_key, + ) + + +def test_ecdh_1pu_decryption_with_public_recipient_key_fails(): + jwe = JsonWebEncryption() + protected = {"alg": "ECDH-1PU", "enc": "A256GCM"} + + alice_key = { + "kty": "EC", + "crv": "P-256", + "x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis", + "y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE", + "d": "Hndv7ZZjs_ke8o9zXYo3iq-Yr8SewI5vrqd0pAvEPqg", + } + bob_key = { + "kty": "EC", + "crv": "P-256", + "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ", + "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck", + } + data = jwe.serialize_compact(protected, b"hello", bob_key, sender_key=alice_key) + with pytest.raises(ValueError): + jwe.deserialize_compact(data, bob_key, sender_key=alice_key) + + +def test_ecdh_1pu_encryption_fails_if_key_types_are_different(): + jwe = JsonWebEncryption() + protected = {"alg": "ECDH-1PU", "enc": "A256GCM"} + + alice_key = ECKey.generate_key("P-256", is_private=True) + bob_key = OKPKey.generate_key("X25519", is_private=False) + with pytest.raises(TypeError): + jwe.serialize_compact( + protected, + b"hello", + bob_key, + sender_key=alice_key, + ) + + alice_key = OKPKey.generate_key("X25519", is_private=True) + bob_key = ECKey.generate_key("P-256", is_private=False) + with pytest.raises(TypeError): + jwe.serialize_compact( + protected, + b"hello", + bob_key, + sender_key=alice_key, + ) + + +def test_ecdh_1pu_encryption_fails_if_keys_curves_are_different(): + jwe = JsonWebEncryption() + protected = {"alg": "ECDH-1PU", "enc": "A256GCM"} + + alice_key = ECKey.generate_key("P-256", is_private=True) + bob_key = ECKey.generate_key("secp256k1", is_private=False) + with pytest.raises(ValueError): + jwe.serialize_compact( + protected, + b"hello", + bob_key, + sender_key=alice_key, + ) + + alice_key = ECKey.generate_key("P-384", is_private=True) + bob_key = ECKey.generate_key("P-521", is_private=False) + with pytest.raises(ValueError): + jwe.serialize_compact( + protected, + b"hello", + bob_key, + sender_key=alice_key, + ) + + alice_key = OKPKey.generate_key("X25519", is_private=True) + bob_key = OKPKey.generate_key("X448", is_private=False) + with pytest.raises(TypeError): + jwe.serialize_compact( + protected, + b"hello", + bob_key, + sender_key=alice_key, + ) + + +def test_ecdh_1pu_encryption_fails_if_key_points_are_not_actually_on_same_curve(): + jwe = JsonWebEncryption() + protected = {"alg": "ECDH-1PU", "enc": "A256GCM"} + + alice_key = { + "kty": "EC", + "crv": "P-256", + "x": "aDHtGkIYyhR5geqfMaFL0T9cG4JEMI8nyMFJA7gRUDs", + "y": "AjGN5_f-aCt4vYg74my6n1ALIq746nlc_httIgcBSYY", + "d": "Sim3EIzXsWaWu9QW8yKVHwxBM5CTlnrVU_Eq-y_KRQA", + } # the point is indeed on P-256 curve + bob_key = { + "kty": "EC", + "crv": "P-256", + "x": "5ZFnZbs_BtLBIZxwt5hS7SBDtI2a-dJ871dJ8ZnxZ6c", + "y": "K0srqSkbo1Yeckr0YoQA8r_rOz0ZUStiv3mc1qn46pg", + } # the point is not on P-256 curve but is actually on secp256k1 curve + + with pytest.raises(ValueError): + jwe.serialize_compact( + protected, + b"hello", + bob_key, + sender_key=alice_key, + ) + + alice_key = { + "kty": "EC", + "crv": "P-521", + "x": "1JDMOjnMgASo01PVHRcyCDtE6CLgKuwXLXLbdLGxpdubLuHYBa0KAepyimnxCWsX", + "y": "w7BSC8Xb3XgMMfE7IFCJpoOmx1Sf3T3_3OZ4CrF6_iCFAw4VOdFYR42OnbKMFG--", + "d": "lCkpFBaVwHzfHtkJEV3PzxefObOPnMgUjNZSLryqC5AkERgXT3-DZLEi6eBzq5gk", + } # the point is not on P-521 curve but is actually on P-384 curve + bob_key = { + "kty": "EC", + "crv": "P-521", + "x": "Cd6rinJdgS4WJj6iaNyXiVhpMbhZLmPykmrnFhIad04B3ulf5pURb5v9mx21c_Cv8Q1RBOptwleLg5Qjq2J1qa4", + "y": "hXo9p1EjW6W4opAQdmfNgyxztkNxYwn9L4FVTLX51KNEsW0aqueLm96adRmf0HoGIbNhIdcIlXOKlRUHqgunDkM", + } # the point is indeed on P-521 curve + + with pytest.raises(ValueError): + jwe.serialize_compact( + protected, + b"hello", + bob_key, + sender_key=alice_key, + ) + + alice_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "TAB1oIsjPob3guKwTEeQsAsupSRPdXdxHhnV8JrVJTA", + "d": "kO2LzPr4vLg_Hn-7_MDq66hJZgvTIkzDG4p6nCsgNHk", + } + ) # the point is indeed on X25519 curve + bob_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "lVHcPx4R9bExaoxXZY9tAq7SNW9pJKCoVQxURLtkAs3Dg5ZRxcjhf0JUyg2lod5OGDptJ7wowwY", + } + ) # the point is not on X25519 curve but is actually on X448 curve + + with pytest.raises(ValueError): + jwe.serialize_compact( + protected, + b"hello", + bob_key, + sender_key=alice_key, + ) + + alice_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X448", + "x": "TAB1oIsjPob3guKwTEeQsAsupSRPdXdxHhnV8JrVJTA", + "d": "kO2LzPr4vLg_Hn-7_MDq66hJZgvTIkzDG4p6nCsgNHk", + } + ) # the point is not on X448 curve but is actually on X25519 curve + bob_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X448", + "x": "lVHcPx4R9bExaoxXZY9tAq7SNW9pJKCoVQxURLtkAs3Dg5ZRxcjhf0JUyg2lod5OGDptJ7wowwY", + } + ) # the point is indeed on X448 curve + + with pytest.raises(ValueError): + jwe.serialize_compact( + protected, + b"hello", + bob_key, + sender_key=alice_key, + ) + + +def test_ecdh_1pu_encryption_fails_if_keys_curve_is_inappropriate(): + jwe = JsonWebEncryption() + protected = {"alg": "ECDH-1PU", "enc": "A256GCM"} + + alice_key = OKPKey.generate_key( + "Ed25519", is_private=True + ) # use Ed25519 instead of X25519 + bob_key = OKPKey.generate_key( + "Ed25519", is_private=False + ) # use Ed25519 instead of X25519 + with pytest.raises(ValueError): + jwe.serialize_compact( + protected, + b"hello", + bob_key, + sender_key=alice_key, + ) + + +def test_ecdh_1pu_encryption_for_multiple_recipients_fails_if_key_types_are_different(): + jwe = JsonWebEncryption() + protected = {"alg": "ECDH-1PU+A128KW", "enc": "A128CBC-HS256"} + header_obj = {"protected": protected} + + alice_key = ECKey.generate_key("P-256", is_private=True) + bob_key = ECKey.generate_key("P-256", is_private=False) + charlie_key = OKPKey.generate_key("X25519", is_private=False) + + with pytest.raises(TypeError): + jwe.serialize_json( + header_obj, + b"hello", + [bob_key, charlie_key], + sender_key=alice_key, + ) + + +def test_ecdh_1pu_encryption_for_multiple_recipients_fails_if_keys_curves_are_different(): + jwe = JsonWebEncryption() + protected = {"alg": "ECDH-1PU+A128KW", "enc": "A128CBC-HS256"} + header_obj = {"protected": protected} + + alice_key = OKPKey.generate_key("X25519", is_private=True) + bob_key = OKPKey.generate_key("X448", is_private=False) + charlie_key = OKPKey.generate_key("X25519", is_private=False) + + with pytest.raises(TypeError): + jwe.serialize_json( + header_obj, + b"hello", + [bob_key, charlie_key], + sender_key=alice_key, + ) + + +def test_ecdh_1pu_encryption_for_multiple_recipients_fails_if_key_points_are_not_actually_on_same_curve(): + jwe = JsonWebEncryption() + protected = {"alg": "ECDH-1PU+A128KW", "enc": "A128CBC-HS256"} + header_obj = {"protected": protected} + + alice_key = { + "kty": "EC", + "crv": "P-256", + "x": "aDHtGkIYyhR5geqfMaFL0T9cG4JEMI8nyMFJA7gRUDs", + "y": "AjGN5_f-aCt4vYg74my6n1ALIq746nlc_httIgcBSYY", + "d": "Sim3EIzXsWaWu9QW8yKVHwxBM5CTlnrVU_Eq-y_KRQA", + } # the point is indeed on P-256 curve + bob_key = { + "kty": "EC", + "crv": "P-256", + "x": "HgF88mm6yw4gjG7yG6Sqz66pHnpZcyx7c842BQghYuc", + "y": "KZ1ywvTOYnpNb4Gepa5eSgfEOb5gj5hCaCFIrTFuI2o", + } # the point is indeed on P-256 curve + charlie_key = { + "kty": "EC", + "crv": "P-256", + "x": "5ZFnZbs_BtLBIZxwt5hS7SBDtI2a-dJ871dJ8ZnxZ6c", + "y": "K0srqSkbo1Yeckr0YoQA8r_rOz0ZUStiv3mc1qn46pg", + } # the point is not on P-256 curve but is actually on secp256k1 curve + + with pytest.raises(ValueError): + jwe.serialize_json( + header_obj, + b"hello", + [bob_key, charlie_key], + sender_key=alice_key, + ) + + +def test_ecdh_1pu_encryption_for_multiple_recipients_fails_if_keys_curve_is_inappropriate(): + jwe = JsonWebEncryption() + protected = {"alg": "ECDH-1PU+A128KW", "enc": "A128CBC-HS256"} + header_obj = {"protected": protected} + + alice_key = OKPKey.generate_key( + "Ed25519", is_private=True + ) # use Ed25519 instead of X25519 + bob_key = OKPKey.generate_key( + "Ed25519", is_private=False + ) # use Ed25519 instead of X25519 + charlie_key = OKPKey.generate_key( + "Ed25519", is_private=False + ) # use Ed25519 instead of X25519 + + with pytest.raises(ValueError): + jwe.serialize_json( + header_obj, + b"hello", + [bob_key, charlie_key], + sender_key=alice_key, + ) + + +def test_ecdh_1pu_decryption_fails_if_key_matches_to_no_recipient(): + jwe = JsonWebEncryption() + + alice_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4", + "d": "i9KuFhSzEBsiv3PKVL5115OCdsqQai5nj_Flzfkw5jU", + } + ) + bob_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw", + "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg", + } + ) + charlie_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE", + "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE", + } + ) + + protected = { + "alg": "ECDH-1PU+A128KW", + "enc": "A256CBC-HS512", + "apu": "QWxpY2U", + "apv": "Qm9i", + } + + unprotected = {"jku": "https://provider.test/jwks"} + + recipients = [{"header": {"kid": "bob-key-2"}}] + + jwe_aad = b"Authenticate me too." + + header_obj = { + "protected": protected, + "unprotected": unprotected, + "recipients": recipients, + "aad": jwe_aad, + } + + payload = b"Three is a magic number." + + data = jwe.serialize_json(header_obj, payload, bob_key, sender_key=alice_key) + + with pytest.raises(InvalidUnwrap): + jwe.deserialize_json(data, charlie_key, sender_key=alice_key) diff --git a/tests/jose/test_jwe.py b/tests/jose/test_jwe.py new file mode 100644 index 000000000..a59c9ad24 --- /dev/null +++ b/tests/jose/test_jwe.py @@ -0,0 +1,1544 @@ +import json +import os + +import pytest +from cryptography.exceptions import InvalidTag +from cryptography.hazmat.primitives.keywrap import InvalidUnwrap + +from authlib.common.encoding import json_b64encode +from authlib.common.encoding import to_bytes +from authlib.common.encoding import to_unicode +from authlib.common.encoding import urlsafe_b64encode +from authlib.jose import JsonWebEncryption +from authlib.jose import OctKey +from authlib.jose import OKPKey +from authlib.jose import errors +from authlib.jose.drafts import register_jwe_draft +from authlib.jose.errors import DecodeError +from authlib.jose.errors import InvalidAlgorithmForMultipleRecipientsMode +from authlib.jose.errors import InvalidHeaderParameterNameError +from authlib.jose.util import extract_header +from tests.util import read_file_path + +register_jwe_draft(JsonWebEncryption) + + +def test_not_enough_segments(): + s = "a.b.c" + jwe = JsonWebEncryption() + with pytest.raises(errors.DecodeError): + jwe.deserialize_compact(s, None) + + +def test_invalid_header(): + jwe = JsonWebEncryption() + public_key = read_file_path("rsa_public.pem") + with pytest.raises(errors.MissingAlgorithmError): + jwe.serialize_compact({}, "a", public_key) + with pytest.raises(errors.UnsupportedAlgorithmError): + jwe.serialize_compact( + {"alg": "invalid"}, + "a", + public_key, + ) + with pytest.raises(errors.MissingEncryptionAlgorithmError): + jwe.serialize_compact( + {"alg": "RSA-OAEP"}, + "a", + public_key, + ) + with pytest.raises(errors.UnsupportedEncryptionAlgorithmError): + jwe.serialize_compact( + {"alg": "RSA-OAEP", "enc": "invalid"}, + "a", + public_key, + ) + with pytest.raises(errors.UnsupportedCompressionAlgorithmError): + jwe.serialize_compact( + {"alg": "RSA-OAEP", "enc": "A256GCM", "zip": "invalid"}, + "a", + public_key, + ) + + +def test_not_supported_alg(): + public_key = read_file_path("rsa_public.pem") + private_key = read_file_path("rsa_private.pem") + + jwe = JsonWebEncryption() + s = jwe.serialize_compact( + {"alg": "RSA-OAEP", "enc": "A256GCM"}, "hello", public_key + ) + + jwe = JsonWebEncryption(algorithms=["RSA1_5", "A256GCM"]) + with pytest.raises(errors.UnsupportedAlgorithmError): + jwe.serialize_compact( + {"alg": "RSA-OAEP", "enc": "A256GCM"}, + "hello", + public_key, + ) + with pytest.raises(errors.UnsupportedCompressionAlgorithmError): + jwe.serialize_compact( + {"alg": "RSA1_5", "enc": "A256GCM", "zip": "DEF"}, + "hello", + public_key, + ) + with pytest.raises(errors.UnsupportedAlgorithmError): + jwe.deserialize_compact( + s, + private_key, + ) + + jwe = JsonWebEncryption(algorithms=["RSA-OAEP", "A192GCM"]) + with pytest.raises(errors.UnsupportedEncryptionAlgorithmError): + jwe.serialize_compact( + {"alg": "RSA-OAEP", "enc": "A256GCM"}, + "hello", + public_key, + ) + with pytest.raises(errors.UnsupportedCompressionAlgorithmError): + jwe.serialize_compact( + {"alg": "RSA-OAEP", "enc": "A192GCM", "zip": "DEF"}, + "hello", + public_key, + ) + with pytest.raises(errors.UnsupportedEncryptionAlgorithmError): + jwe.deserialize_compact( + s, + private_key, + ) + + +def test_inappropriate_sender_key_for_serialize_compact(): + jwe = JsonWebEncryption() + alice_key = { + "kty": "EC", + "crv": "P-256", + "x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis", + "y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE", + "d": "Hndv7ZZjs_ke8o9zXYo3iq-Yr8SewI5vrqd0pAvEPqg", + } + bob_key = { + "kty": "EC", + "crv": "P-256", + "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ", + "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck", + "d": "VEmDZpDXXK8p8N0Cndsxs924q6nS1RXFASRl6BfUqdw", + } + + protected = {"alg": "ECDH-1PU", "enc": "A256GCM"} + with pytest.raises(ValueError): + jwe.serialize_compact(protected, b"hello", bob_key) + + protected = {"alg": "ECDH-ES", "enc": "A256GCM"} + with pytest.raises(ValueError): + jwe.serialize_compact( + protected, + b"hello", + bob_key, + sender_key=alice_key, + ) + + +def test_inappropriate_sender_key_for_deserialize_compact(): + jwe = JsonWebEncryption() + alice_key = { + "kty": "EC", + "crv": "P-256", + "x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis", + "y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE", + "d": "Hndv7ZZjs_ke8o9zXYo3iq-Yr8SewI5vrqd0pAvEPqg", + } + bob_key = { + "kty": "EC", + "crv": "P-256", + "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ", + "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck", + "d": "VEmDZpDXXK8p8N0Cndsxs924q6nS1RXFASRl6BfUqdw", + } + + protected = {"alg": "ECDH-1PU", "enc": "A256GCM"} + data = jwe.serialize_compact(protected, b"hello", bob_key, sender_key=alice_key) + with pytest.raises(ValueError): + jwe.deserialize_compact(data, bob_key) + + protected = {"alg": "ECDH-ES", "enc": "A256GCM"} + data = jwe.serialize_compact(protected, b"hello", bob_key) + with pytest.raises(ValueError): + jwe.deserialize_compact(data, bob_key, sender_key=alice_key) + + +def test_compact_rsa(): + jwe = JsonWebEncryption() + s = jwe.serialize_compact( + {"alg": "RSA-OAEP", "enc": "A256GCM"}, + "hello", + read_file_path("rsa_public.pem"), + ) + data = jwe.deserialize_compact(s, read_file_path("rsa_private.pem")) + header, payload = data["header"], data["payload"] + assert payload == b"hello" + assert header["alg"] == "RSA-OAEP" + + +def test_with_zip_header(): + jwe = JsonWebEncryption() + s = jwe.serialize_compact( + {"alg": "RSA-OAEP", "enc": "A128CBC-HS256", "zip": "DEF"}, + "hello", + read_file_path("rsa_public.pem"), + ) + data = jwe.deserialize_compact(s, read_file_path("rsa_private.pem")) + header, payload = data["header"], data["payload"] + assert payload == b"hello" + assert header["alg"] == "RSA-OAEP" + + +def test_aes_jwe(): + jwe = JsonWebEncryption() + sizes = [128, 192, 256] + _enc_choices = [ + "A128CBC-HS256", + "A192CBC-HS384", + "A256CBC-HS512", + "A128GCM", + "A192GCM", + "A256GCM", + ] + for s in sizes: + alg = f"A{s}KW" + key = os.urandom(s // 8) + for enc in _enc_choices: + protected = {"alg": alg, "enc": enc} + data = jwe.serialize_compact(protected, b"hello", key) + rv = jwe.deserialize_compact(data, key) + assert rv["payload"] == b"hello" + + +def test_aes_jwe_invalid_key(): + jwe = JsonWebEncryption() + protected = {"alg": "A128KW", "enc": "A128GCM"} + with pytest.raises(ValueError): + jwe.serialize_compact(protected, b"hello", b"invalid-key") + + +def test_aes_gcm_jwe(): + jwe = JsonWebEncryption() + sizes = [128, 192, 256] + _enc_choices = [ + "A128CBC-HS256", + "A192CBC-HS384", + "A256CBC-HS512", + "A128GCM", + "A192GCM", + "A256GCM", + ] + for s in sizes: + alg = f"A{s}GCMKW" + key = os.urandom(s // 8) + for enc in _enc_choices: + protected = {"alg": alg, "enc": enc} + data = jwe.serialize_compact(protected, b"hello", key) + rv = jwe.deserialize_compact(data, key) + assert rv["payload"] == b"hello" + + +def test_aes_gcm_jwe_invalid_key(): + jwe = JsonWebEncryption() + protected = {"alg": "A128GCMKW", "enc": "A128GCM"} + with pytest.raises(ValueError): + jwe.serialize_compact(protected, b"hello", b"invalid-key") + + +def test_serialize_compact_fails_if_header_contains_unknown_field_while_private_fields_restricted(): + jwe = JsonWebEncryption(private_headers=set()) + key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-ES+A128KW", "enc": "A128GCM", "foo": "bar"} + + with pytest.raises(InvalidHeaderParameterNameError): + jwe.serialize_compact( + protected, + b"hello", + key, + ) + + +def test_serialize_compact_allows_unknown_fields_in_header_while_private_fields_not_restricted(): + jwe = JsonWebEncryption() + key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-ES+A128KW", "enc": "A128GCM", "foo": "bar"} + + data = jwe.serialize_compact(protected, b"hello", key) + rv = jwe.deserialize_compact(data, key) + assert rv["payload"] == b"hello" + + +def test_serialize_json_fails_if_protected_header_contains_unknown_field_while_private_fields_restricted(): + jwe = JsonWebEncryption(private_headers=set()) + key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-ES+A128KW", "enc": "A128GCM", "foo": "bar"} + header_obj = {"protected": protected} + + with pytest.raises(InvalidHeaderParameterNameError): + jwe.serialize_json( + header_obj, + b"hello", + key, + ) + + +def test_serialize_json_fails_if_unprotected_header_contains_unknown_field_while_private_fields_restricted(): + jwe = JsonWebEncryption(private_headers=set()) + key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-ES+A128KW", "enc": "A128GCM"} + unprotected = {"foo": "bar"} + header_obj = {"protected": protected, "unprotected": unprotected} + + with pytest.raises(InvalidHeaderParameterNameError): + jwe.serialize_json( + header_obj, + b"hello", + key, + ) + + +def test_serialize_json_fails_if_recipient_header_contains_unknown_field_while_private_fields_restricted(): + jwe = JsonWebEncryption(private_headers=set()) + key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-ES+A128KW", "enc": "A128GCM"} + recipients = [{"header": {"foo": "bar"}}] + header_obj = {"protected": protected, "recipients": recipients} + + with pytest.raises(InvalidHeaderParameterNameError): + jwe.serialize_json( + header_obj, + b"hello", + key, + ) + + +def test_serialize_json_allows_unknown_fields_in_headers_while_private_fields_not_restricted(): + jwe = JsonWebEncryption() + key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-ES+A128KW", "enc": "A128GCM", "foo1": "bar1"} + unprotected = {"foo2": "bar2"} + recipients = [{"header": {"foo3": "bar3"}}] + header_obj = { + "protected": protected, + "unprotected": unprotected, + "recipients": recipients, + } + + data = jwe.serialize_json(header_obj, b"hello", key) + rv = jwe.deserialize_json(data, key) + assert rv["payload"] == b"hello" + + +def test_serialize_json_ignores_additional_members_in_recipients_elements(): + jwe = JsonWebEncryption() + key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-ES+A128KW", "enc": "A128GCM"} + + data = jwe.serialize_compact(protected, b"hello", key) + rv = jwe.deserialize_compact(data, key) + assert rv["payload"] == b"hello" + + +def test_deserialize_json_fails_if_protected_header_contains_unknown_field_while_private_fields_restricted(): + jwe = JsonWebEncryption(private_headers=set()) + key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-ES+A128KW", "enc": "A128GCM"} + header_obj = {"protected": protected} + + data = jwe.serialize_json(header_obj, b"hello", key) + + decoded_protected = extract_header(to_bytes(data["protected"]), DecodeError) + decoded_protected["foo"] = "bar" + data["protected"] = to_unicode(json_b64encode(decoded_protected)) + + with pytest.raises(InvalidHeaderParameterNameError): + jwe.deserialize_json(data, key) + + +def test_deserialize_json_fails_if_unprotected_header_contains_unknown_field_while_private_fields_restricted(): + jwe = JsonWebEncryption(private_headers=set()) + key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-ES+A128KW", "enc": "A128GCM"} + header_obj = {"protected": protected} + + data = jwe.serialize_json(header_obj, b"hello", key) + + data["unprotected"] = {"foo": "bar"} + + with pytest.raises(InvalidHeaderParameterNameError): + jwe.deserialize_json(data, key) + + +def test_deserialize_json_fails_if_recipient_header_contains_unknown_field_while_private_fields_restricted(): + jwe = JsonWebEncryption(private_headers=set()) + key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-ES+A128KW", "enc": "A128GCM"} + header_obj = {"protected": protected} + + data = jwe.serialize_json(header_obj, b"hello", key) + + data["recipients"][0]["header"] = {"foo": "bar"} + + with pytest.raises(InvalidHeaderParameterNameError): + jwe.deserialize_json(data, key) + + +def test_deserialize_json_allows_unknown_fields_in_headers_while_private_fields_not_restricted(): + jwe = JsonWebEncryption() + key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-ES+A128KW", "enc": "A128GCM"} + header_obj = {"protected": protected} + + data = jwe.serialize_json(header_obj, b"hello", key) + + data["unprotected"] = {"foo1": "bar1"} + data["recipients"][0]["header"] = {"foo2": "bar2"} + + rv = jwe.deserialize_json(data, key) + assert rv["payload"] == b"hello" + + +def test_deserialize_json_ignores_additional_members_in_recipients_elements(): + jwe = JsonWebEncryption() + key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-ES+A128KW", "enc": "A128GCM"} + header_obj = {"protected": protected} + + data = jwe.serialize_json(header_obj, b"hello", key) + + data["recipients"][0]["foo"] = "bar" + + data = jwe.serialize_compact(protected, b"hello", key) + rv = jwe.deserialize_compact(data, key) + assert rv["payload"] == b"hello" + + +def test_deserialize_json_ignores_additional_members_in_jwe_message(): + jwe = JsonWebEncryption() + key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-ES+A128KW", "enc": "A128GCM"} + header_obj = {"protected": protected} + + data = jwe.serialize_json(header_obj, b"hello", key) + + data["foo"] = "bar" + + data = jwe.serialize_compact(protected, b"hello", key) + rv = jwe.deserialize_compact(data, key) + assert rv["payload"] == b"hello" + + +def test_ecdh_es_key_agreement_computation(): + # https://tools.ietf.org/html/rfc7518#appendix-C + alice_ephemeral_key = { + "kty": "EC", + "crv": "P-256", + "x": "gI0GAILBdu7T53akrFmMyGcsF3n5dO7MmwNBHKW5SV0", + "y": "SLW_xSffzlPWrHEVI30DHM_4egVwt3NQqeUD7nMFpps", + "d": "0_NxaRPUMQoAJt50Gz8YiTr8gRTwyEaCumd-MToTmIo", + } + bob_static_key = { + "kty": "EC", + "crv": "P-256", + "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ", + "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck", + "d": "VEmDZpDXXK8p8N0Cndsxs924q6nS1RXFASRl6BfUqdw", + } + + headers = { + "alg": "ECDH-ES", + "enc": "A128GCM", + "apu": "QWxpY2U", + "apv": "Qm9i", + "epk": { + "kty": "EC", + "crv": "P-256", + "x": "gI0GAILBdu7T53akrFmMyGcsF3n5dO7MmwNBHKW5SV0", + "y": "SLW_xSffzlPWrHEVI30DHM_4egVwt3NQqeUD7nMFpps", + }, + } + + alg = JsonWebEncryption.ALG_REGISTRY["ECDH-ES"] + enc = JsonWebEncryption.ENC_REGISTRY["A128GCM"] + + alice_ephemeral_key = alg.prepare_key(alice_ephemeral_key) + bob_static_key = alg.prepare_key(bob_static_key) + + alice_ephemeral_pubkey = alice_ephemeral_key.get_op_key("wrapKey") + bob_static_pubkey = bob_static_key.get_op_key("wrapKey") + + # Derived key computation at Alice + + # Step-by-step methods verification + _shared_key_at_alice = alice_ephemeral_key.exchange_shared_key(bob_static_pubkey) + assert _shared_key_at_alice == bytes( + [ + 158, + 86, + 217, + 29, + 129, + 113, + 53, + 211, + 114, + 131, + 66, + 131, + 191, + 132, + 38, + 156, + 251, + 49, + 110, + 163, + 218, + 128, + 106, + 72, + 246, + 218, + 167, + 121, + 140, + 254, + 144, + 196, + ] + ) + + _fixed_info_at_alice = alg.compute_fixed_info(headers, enc.key_size) + assert _fixed_info_at_alice == bytes( + [ + 0, + 0, + 0, + 7, + 65, + 49, + 50, + 56, + 71, + 67, + 77, + 0, + 0, + 0, + 5, + 65, + 108, + 105, + 99, + 101, + 0, + 0, + 0, + 3, + 66, + 111, + 98, + 0, + 0, + 0, + 128, + ] + ) + + _dk_at_alice = alg.compute_derived_key( + _shared_key_at_alice, _fixed_info_at_alice, enc.key_size + ) + assert _dk_at_alice == bytes( + [86, 170, 141, 234, 248, 35, 109, 32, 92, 34, 40, 205, 113, 167, 16, 26] + ) + assert urlsafe_b64encode(_dk_at_alice) == b"VqqN6vgjbSBcIijNcacQGg" + + # All-in-one method verification + dk_at_alice = alg.deliver( + alice_ephemeral_key, bob_static_pubkey, headers, enc.key_size + ) + assert dk_at_alice == bytes( + [86, 170, 141, 234, 248, 35, 109, 32, 92, 34, 40, 205, 113, 167, 16, 26] + ) + assert urlsafe_b64encode(dk_at_alice) == b"VqqN6vgjbSBcIijNcacQGg" + + # Derived key computation at Bob + + # Step-by-step methods verification + _shared_key_at_bob = bob_static_key.exchange_shared_key(alice_ephemeral_pubkey) + assert _shared_key_at_bob == _shared_key_at_alice + + _fixed_info_at_bob = alg.compute_fixed_info(headers, enc.key_size) + assert _fixed_info_at_bob == _fixed_info_at_alice + + _dk_at_bob = alg.compute_derived_key( + _shared_key_at_bob, _fixed_info_at_bob, enc.key_size + ) + assert _dk_at_bob == _dk_at_alice + + # All-in-one method verification + dk_at_bob = alg.deliver( + bob_static_key, alice_ephemeral_pubkey, headers, enc.key_size + ) + assert dk_at_bob == dk_at_alice + + +def test_ecdh_es_jwe_in_direct_key_agreement_mode(): + jwe = JsonWebEncryption() + key = { + "kty": "EC", + "crv": "P-256", + "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ", + "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck", + "d": "VEmDZpDXXK8p8N0Cndsxs924q6nS1RXFASRl6BfUqdw", + } + + for enc in [ + "A128CBC-HS256", + "A192CBC-HS384", + "A256CBC-HS512", + "A128GCM", + "A192GCM", + "A256GCM", + ]: + protected = {"alg": "ECDH-ES", "enc": enc} + data = jwe.serialize_compact(protected, b"hello", key) + rv = jwe.deserialize_compact(data, key) + assert rv["payload"] == b"hello" + + +def test_ecdh_es_jwe_json_serialization_single_recipient_in_direct_key_agreement_mode(): + jwe = JsonWebEncryption() + key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-ES", "enc": "A128GCM"} + header_obj = {"protected": protected} + data = jwe.serialize_json(header_obj, b"hello", key) + rv = jwe.deserialize_json(data, key) + assert rv["payload"] == b"hello" + + +def test_ecdh_es_jwe_in_key_agreement_with_key_wrapping_mode(): + jwe = JsonWebEncryption() + key = { + "kty": "EC", + "crv": "P-256", + "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ", + "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck", + "d": "VEmDZpDXXK8p8N0Cndsxs924q6nS1RXFASRl6BfUqdw", + } + + for alg in [ + "ECDH-ES+A128KW", + "ECDH-ES+A192KW", + "ECDH-ES+A256KW", + ]: + for enc in [ + "A128CBC-HS256", + "A192CBC-HS384", + "A256CBC-HS512", + "A128GCM", + "A192GCM", + "A256GCM", + ]: + protected = {"alg": alg, "enc": enc} + data = jwe.serialize_compact(protected, b"hello", key) + rv = jwe.deserialize_compact(data, key) + assert rv["payload"] == b"hello" + + +def test_ecdh_es_jwe_with_okp_key_in_direct_key_agreement_mode(): + jwe = JsonWebEncryption() + key = OKPKey.generate_key("X25519", is_private=True) + + for enc in [ + "A128CBC-HS256", + "A192CBC-HS384", + "A256CBC-HS512", + "A128GCM", + "A192GCM", + "A256GCM", + ]: + protected = {"alg": "ECDH-ES", "enc": enc} + data = jwe.serialize_compact(protected, b"hello", key) + rv = jwe.deserialize_compact(data, key) + assert rv["payload"] == b"hello" + + +def test_ecdh_es_jwe_with_okp_key_in_key_agreement_with_key_wrapping_mode(): + jwe = JsonWebEncryption() + key = OKPKey.generate_key("X25519", is_private=True) + + for alg in [ + "ECDH-ES+A128KW", + "ECDH-ES+A192KW", + "ECDH-ES+A256KW", + ]: + for enc in [ + "A128CBC-HS256", + "A192CBC-HS384", + "A256CBC-HS512", + "A128GCM", + "A192GCM", + "A256GCM", + ]: + protected = {"alg": alg, "enc": enc} + data = jwe.serialize_compact(protected, b"hello", key) + rv = jwe.deserialize_compact(data, key) + assert rv["payload"] == b"hello" + + +def test_ecdh_es_jwe_with_json_serialization_when_kid_is_not_specified(): + jwe = JsonWebEncryption() + + bob_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw", + "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg", + } + ) + charlie_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE", + "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE", + } + ) + + protected = { + "alg": "ECDH-ES+A256KW", + "enc": "A256GCM", + "apu": "QWxpY2U", + "apv": "Qm9iIGFuZCBDaGFybGll", + } + + unprotected = {"jku": "https://provider.test/jwks"} + + recipients = [ + {"header": {"kid": "bob-key-2"}}, + {"header": {"kid": "2021-05-06"}}, + ] + + jwe_aad = b"Authenticate me too." + + header_obj = { + "protected": protected, + "unprotected": unprotected, + "recipients": recipients, + "aad": jwe_aad, + } + + payload = b"Three is a magic number." + + data = jwe.serialize_json(header_obj, payload, [bob_key, charlie_key]) + + rv_at_bob = jwe.deserialize_json(data, bob_key) + + assert rv_at_bob["header"]["protected"].keys() == protected.keys() | {"epk"} + assert { + k: rv_at_bob["header"]["protected"][k] + for k in rv_at_bob["header"]["protected"].keys() - {"epk"} + } == protected + assert rv_at_bob["header"]["unprotected"] == unprotected + assert rv_at_bob["header"]["recipients"] == recipients + assert rv_at_bob["header"]["aad"] == jwe_aad + assert rv_at_bob["payload"] == payload + + rv_at_charlie = jwe.deserialize_json(data, charlie_key) + + assert rv_at_charlie["header"]["protected"].keys() == protected.keys() | {"epk"} + assert { + k: rv_at_charlie["header"]["protected"][k] + for k in rv_at_charlie["header"]["protected"].keys() - {"epk"} + } == protected + assert rv_at_charlie["header"]["unprotected"] == unprotected + assert rv_at_charlie["header"]["recipients"] == recipients + assert rv_at_charlie["header"]["aad"] == jwe_aad + assert rv_at_charlie["payload"] == payload + + +def test_ecdh_es_jwe_with_json_serialization_when_kid_is_specified(): + jwe = JsonWebEncryption() + + bob_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "kid": "bob-key-2", + "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw", + "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg", + } + ) + charlie_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "kid": "2021-05-06", + "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE", + "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE", + } + ) + + protected = { + "alg": "ECDH-ES+A256KW", + "enc": "A256GCM", + "apu": "QWxpY2U", + "apv": "Qm9iIGFuZCBDaGFybGll", + } + + unprotected = {"jku": "https://provider.test/jwks"} + + recipients = [ + {"header": {"kid": "bob-key-2"}}, + {"header": {"kid": "2021-05-06"}}, + ] + + jwe_aad = b"Authenticate me too." + + header_obj = { + "protected": protected, + "unprotected": unprotected, + "recipients": recipients, + "aad": jwe_aad, + } + + payload = b"Three is a magic number." + + data = jwe.serialize_json(header_obj, payload, [bob_key, charlie_key]) + + rv_at_bob = jwe.deserialize_json(data, bob_key) + + assert rv_at_bob["header"]["protected"].keys() == protected.keys() | {"epk"} + assert { + k: rv_at_bob["header"]["protected"][k] + for k in rv_at_bob["header"]["protected"].keys() - {"epk"} + } == protected + assert rv_at_bob["header"]["unprotected"] == unprotected + assert rv_at_bob["header"]["recipients"] == recipients + assert rv_at_bob["header"]["aad"] == jwe_aad + assert rv_at_bob["payload"] == payload + + rv_at_charlie = jwe.deserialize_json(data, charlie_key) + + assert rv_at_charlie["header"]["protected"].keys() == protected.keys() | {"epk"} + assert { + k: rv_at_charlie["header"]["protected"][k] + for k in rv_at_charlie["header"]["protected"].keys() - {"epk"} + } == protected + assert rv_at_charlie["header"]["unprotected"] == unprotected + assert rv_at_charlie["header"]["recipients"] == recipients + assert rv_at_charlie["header"]["aad"] == jwe_aad + assert rv_at_charlie["payload"] == payload + + +def test_ecdh_es_jwe_with_json_serialization_for_single_recipient(): + jwe = JsonWebEncryption() + + key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw", + "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg", + } + ) + + protected = { + "alg": "ECDH-ES+A256KW", + "enc": "A256GCM", + "apu": "QWxpY2U", + "apv": "Qm9i", + } + + unprotected = {"jku": "https://provider.test/jwks"} + + recipients = [{"header": {"kid": "bob-key-2"}}] + + jwe_aad = b"Authenticate me too." + + header_obj = { + "protected": protected, + "unprotected": unprotected, + "recipients": recipients, + "aad": jwe_aad, + } + + payload = b"Three is a magic number." + + data = jwe.serialize_json(header_obj, payload, key) + + rv = jwe.deserialize_json(data, key) + + assert rv["header"]["protected"].keys() == protected.keys() | {"epk"} + assert { + k: rv["header"]["protected"][k] + for k in rv["header"]["protected"].keys() - {"epk"} + } == protected + assert rv["header"]["unprotected"] == unprotected + assert rv["header"]["recipients"] == recipients + assert rv["header"]["aad"] == jwe_aad + assert rv["payload"] == payload + + +def test_ecdh_es_encryption_fails_json_serialization_multiple_recipients_in_direct_key_agreement_mode(): + jwe = JsonWebEncryption() + bob_key = OKPKey.generate_key("X25519", is_private=True) + charlie_key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-ES", "enc": "A128GCM"} + header_obj = {"protected": protected} + with pytest.raises(InvalidAlgorithmForMultipleRecipientsMode): + jwe.serialize_json( + header_obj, + b"hello", + [bob_key, charlie_key], + ) + + +def test_ecdh_es_decryption_with_public_key_fails(): + jwe = JsonWebEncryption() + protected = {"alg": "ECDH-ES", "enc": "A128GCM"} + + key = { + "kty": "EC", + "crv": "P-256", + "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ", + "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck", + } + data = jwe.serialize_compact(protected, b"hello", key) + with pytest.raises(ValueError): + jwe.deserialize_compact(data, key) + + +def test_ecdh_es_encryption_fails_if_key_curve_is_inappropriate(): + jwe = JsonWebEncryption() + protected = {"alg": "ECDH-ES", "enc": "A128GCM"} + + key = OKPKey.generate_key("Ed25519", is_private=False) + with pytest.raises(ValueError): + jwe.serialize_compact(protected, b"hello", key) + + +def test_ecdh_es_decryption_fails_if_key_matches_to_no_recipient(): + jwe = JsonWebEncryption() + + bob_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw", + "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg", + } + ) + charlie_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE", + "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE", + } + ) + + protected = { + "alg": "ECDH-ES+A256KW", + "enc": "A256GCM", + "apu": "QWxpY2U", + "apv": "Qm9i", + } + + unprotected = {"jku": "https://provider.test/jwks"} + + recipients = [{"header": {"kid": "bob-key-2"}}] + + jwe_aad = b"Authenticate me too." + + header_obj = { + "protected": protected, + "unprotected": unprotected, + "recipients": recipients, + "aad": jwe_aad, + } + + payload = b"Three is a magic number." + + data = jwe.serialize_json(header_obj, payload, bob_key) + + with pytest.raises(InvalidUnwrap): + jwe.deserialize_json(data, charlie_key) + + +def test_decryption_with_json_serialization_succeeds_while_encrypted_key_for_another_recipient_is_invalid(): + jwe = JsonWebEncryption() + + alice_key = OKPKey.import_key( + { + "kid": "Alice's key", + "kty": "OKP", + "crv": "X25519", + "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4", + "d": "i9KuFhSzEBsiv3PKVL5115OCdsqQai5nj_Flzfkw5jU", + } + ) + OKPKey.import_key( + { + "kid": "Bob's key", + "kty": "OKP", + "crv": "X25519", + "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw", + "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg", + } + ) + charlie_key = OKPKey.import_key( + { + "kid": "Charlie's key", + "kty": "OKP", + "crv": "X25519", + "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE", + "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE", + } + ) + + data = { + "protected": "eyJhbGciOiJFQ0RILTFQVStBMTI4S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwiYXB1Ijoi" + + "UVd4cFkyVSIsImFwdiI6IlFtOWlJR0Z1WkNCRGFHRnliR2xsIiwiZXBrIjp7Imt0eSI6Ik9L" + + "UCIsImNydiI6IlgyNTUxOSIsIngiOiJrOW9mX2NwQWFqeTBwb1c1Z2FpeFhHczluSGt3ZzFB" + + "RnFVQUZhMzlkeUJjIn19", + "unprotected": {"jku": "https://provider.test/jwks"}, + "recipients": [ + { + "header": {"kid": "Bob's key"}, + "encrypted_key": "pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQ" + + "eU1cSl55cQ0hGezJu2N9IY0QM", # Invalid encrypted key + }, + { + "header": {"kid": "Charlie's key"}, + "encrypted_key": "56GVudgRLIMEElQ7DpXsijJVRSWUSDNdbWkdV3g0GUNq6hcT_GkxwnxlPIWrTXCqRpVKQC8" + + "fe4z3PQ2YH2afvjQ28aiCTWFE", # Valid encrypted key + }, + ], + "iv": "AAECAwQFBgcICQoLDA0ODw", + "ciphertext": "Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFYw", + "tag": "HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ", + } + + rv_at_charlie = jwe.deserialize_json(data, charlie_key, sender_key=alice_key) + + assert rv_at_charlie.keys() == {"header", "payload"} + + assert rv_at_charlie["header"].keys() == { + "protected", + "unprotected", + "recipients", + } + + assert rv_at_charlie["header"]["protected"] == { + "alg": "ECDH-1PU+A128KW", + "enc": "A256CBC-HS512", + "apu": "QWxpY2U", + "apv": "Qm9iIGFuZCBDaGFybGll", + "epk": { + "kty": "OKP", + "crv": "X25519", + "x": "k9of_cpAajy0poW5gaixXGs9nHkwg1AFqUAFa39dyBc", + }, + } + + assert rv_at_charlie["header"]["unprotected"] == { + "jku": "https://provider.test/jwks" + } + + assert rv_at_charlie["header"]["recipients"] == [ + {"header": {"kid": "Bob's key"}}, + {"header": {"kid": "Charlie's key"}}, + ] + + assert rv_at_charlie["payload"] == b"Three is a magic number." + + +def test_decryption_with_json_serialization_fails_if_encrypted_key_for_this_recipient_is_invalid(): + jwe = JsonWebEncryption() + + alice_key = OKPKey.import_key( + { + "kid": "Alice's key", + "kty": "OKP", + "crv": "X25519", + "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4", + "d": "i9KuFhSzEBsiv3PKVL5115OCdsqQai5nj_Flzfkw5jU", + } + ) + bob_key = OKPKey.import_key( + { + "kid": "Bob's key", + "kty": "OKP", + "crv": "X25519", + "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw", + "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg", + } + ) + OKPKey.import_key( + { + "kid": "Charlie's key", + "kty": "OKP", + "crv": "X25519", + "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE", + "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE", + } + ) + + data = { + "protected": "eyJhbGciOiJFQ0RILTFQVStBMTI4S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwiYXB1Ijoi" + + "UVd4cFkyVSIsImFwdiI6IlFtOWlJR0Z1WkNCRGFHRnliR2xsIiwiZXBrIjp7Imt0eSI6Ik9L" + + "UCIsImNydiI6IlgyNTUxOSIsIngiOiJrOW9mX2NwQWFqeTBwb1c1Z2FpeFhHczluSGt3ZzFB" + + "RnFVQUZhMzlkeUJjIn19", + "unprotected": {"jku": "https://provider.test/jwks"}, + "recipients": [ + { + "header": {"kid": "Bob's key"}, + "encrypted_key": "pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQ" + + "eU1cSl55cQ0hGezJu2N9IY0QM", # Invalid encrypted key + }, + { + "header": {"kid": "Charlie's key"}, + "encrypted_key": "56GVudgRLIMEElQ7DpXsijJVRSWUSDNdbWkdV3g0GUNq6hcT_GkxwnxlPIWrTXCqRpVKQC8" + + "fe4z3PQ2YH2afvjQ28aiCTWFE", # Valid encrypted key + }, + ], + "iv": "AAECAwQFBgcICQoLDA0ODw", + "ciphertext": "Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFYw", + "tag": "HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ", + } + + with pytest.raises(InvalidUnwrap): + jwe.deserialize_json(data, bob_key, sender_key=alice_key) + + +def test_dir_alg(): + jwe = JsonWebEncryption() + key = OctKey.generate_key(128, is_private=True) + protected = {"alg": "dir", "enc": "A128GCM"} + data = jwe.serialize_compact(protected, b"hello", key) + rv = jwe.deserialize_compact(data, key) + assert rv["payload"] == b"hello" + + key2 = OctKey.generate_key(256, is_private=True) + with pytest.raises(InvalidTag): + jwe.deserialize_compact(data, key2) + + with pytest.raises(ValueError): + jwe.serialize_compact(protected, b"hello", key2) + + +def test_decryption_of_message_to_multiple_recipients_by_matching_key(): + jwe = JsonWebEncryption() + + alice_public_key = OKPKey.import_key( + { + "kid": "WjKgJV7VRw3hmgU6--4v15c0Aewbcvat1BsRFTIqa5Q", + "kty": "OKP", + "crv": "X25519", + "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4", + } + ) + + key_store = {} + + charlie_X448_key_id = "did:example:123#_TKzHv2jFIyvdTGF1Dsgwngfdg3SH6TpDv0Ta1aOEkw" + charlie_X448_key = OKPKey.import_key( + { + "kid": "_TKzHv2jFIyvdTGF1Dsgwngfdg3SH6TpDv0Ta1aOEkw", + "kty": "OKP", + "crv": "X448", + "x": "M-OMugy74ksznVQ-Bp6MC_-GEPSrT8yiAtminJvw0j_UxJtpNHl_hcWMSf_Pfm_ws0vVWvAfwwA", + "d": "VGZPkclj_7WbRaRMzBqxpzXIpc2xz1d3N1ay36UxdVLfKaP33hABBMpddTRv1f-hRsQUNvmlGOg", + } + ) + key_store[charlie_X448_key_id] = charlie_X448_key + + charlie_X25519_key_id = ( + "did:example:123#ZC2jXTO6t4R501bfCXv3RxarZyUbdP2w_psLwMuY6ec" + ) + charlie_X25519_key = OKPKey.import_key( + { + "kid": "ZC2jXTO6t4R501bfCXv3RxarZyUbdP2w_psLwMuY6ec", + "kty": "OKP", + "crv": "X25519", + "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE", + "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE", + } + ) + key_store[charlie_X25519_key_id] = charlie_X25519_key + + data = """ + { + "protected": "eyJhbGciOiJFQ0RILTFQVStBMTI4S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwiYXB1IjoiUVd4cFkyVSIsImFwdiI6IlFtOWlJR0Z1WkNCRGFHRnliR2xsIiwiZXBrIjp7Imt0eSI6Ik9LUCIsImNydiI6IlgyNTUxOSIsIngiOiJrOW9mX2NwQWFqeTBwb1c1Z2FpeFhHczluSGt3ZzFBRnFVQUZhMzlkeUJjIn19", + "unprotected": { + "jku": "https://provider.test/jwks" + }, + "recipients": [ + { + "header": { + "kid": "did:example:123#_Qq0UL2Fq651Q0Fjd6TvnYE-faHiOpRlPVQcY_-tA4A" + }, + "encrypted_key": "pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQeU1cSl55cQ0hGezJu2N9IY0QN" + }, + { + "header": { + "kid": "did:example:123#ZC2jXTO6t4R501bfCXv3RxarZyUbdP2w_psLwMuY6ec" + }, + "encrypted_key": "56GVudgRLIMEElQ7DpXsijJVRSWUSDNdbWkdV3g0GUNq6hcT_GkxwnxlPIWrTXCqRpVKQC8fe4z3PQ2YH2afvjQ28aiCTWFE" + } + ], + "iv": "AAECAwQFBgcICQoLDA0ODw", + "ciphertext": "Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFYw", + "tag": "HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ" + }""" + + parsed_data = jwe.parse_json(data) + + available_key_id = next( + recipient["header"]["kid"] + for recipient in parsed_data["recipients"] + if recipient["header"]["kid"] in key_store.keys() + ) + available_key = key_store[available_key_id] + + rv = jwe.deserialize_json( + parsed_data, (available_key_id, available_key), sender_key=alice_public_key + ) + + assert rv.keys() == {"header", "payload"} + + assert rv["header"].keys() == {"protected", "unprotected", "recipients"} + + assert rv["header"]["protected"] == { + "alg": "ECDH-1PU+A128KW", + "enc": "A256CBC-HS512", + "apu": "QWxpY2U", + "apv": "Qm9iIGFuZCBDaGFybGll", + "epk": { + "kty": "OKP", + "crv": "X25519", + "x": "k9of_cpAajy0poW5gaixXGs9nHkwg1AFqUAFa39dyBc", + }, + } + + assert rv["header"]["unprotected"] == {"jku": "https://provider.test/jwks"} + + assert rv["header"]["recipients"] == [ + { + "header": { + "kid": "did:example:123#_Qq0UL2Fq651Q0Fjd6TvnYE-faHiOpRlPVQcY_-tA4A" + } + }, + { + "header": { + "kid": "did:example:123#ZC2jXTO6t4R501bfCXv3RxarZyUbdP2w_psLwMuY6ec" + } + }, + ] + + assert rv["payload"] == b"Three is a magic number." + + +def test_decryption_of_json_string(): + jwe = JsonWebEncryption() + + alice_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4", + "d": "i9KuFhSzEBsiv3PKVL5115OCdsqQai5nj_Flzfkw5jU", + } + ) + bob_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw", + "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg", + } + ) + charlie_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE", + "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE", + } + ) + + data = """ + { + "protected": "eyJhbGciOiJFQ0RILTFQVStBMTI4S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwiYXB1IjoiUVd4cFkyVSIsImFwdiI6IlFtOWlJR0Z1WkNCRGFHRnliR2xsIiwiZXBrIjp7Imt0eSI6Ik9LUCIsImNydiI6IlgyNTUxOSIsIngiOiJrOW9mX2NwQWFqeTBwb1c1Z2FpeFhHczluSGt3ZzFBRnFVQUZhMzlkeUJjIn19", + "unprotected": { + "jku": "https://provider.test/jwks" + }, + "recipients": [ + { + "header": { + "kid": "bob-key-2" + }, + "encrypted_key": "pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQeU1cSl55cQ0hGezJu2N9IY0QN" + }, + { + "header": { + "kid": "2021-05-06" + }, + "encrypted_key": "56GVudgRLIMEElQ7DpXsijJVRSWUSDNdbWkdV3g0GUNq6hcT_GkxwnxlPIWrTXCqRpVKQC8fe4z3PQ2YH2afvjQ28aiCTWFE" + } + ], + "iv": "AAECAwQFBgcICQoLDA0ODw", + "ciphertext": "Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFYw", + "tag": "HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ" + }""" + + rv_at_bob = jwe.deserialize_json(data, bob_key, sender_key=alice_key) + + assert rv_at_bob.keys() == {"header", "payload"} + + assert rv_at_bob["header"].keys() == {"protected", "unprotected", "recipients"} + + assert rv_at_bob["header"]["protected"] == { + "alg": "ECDH-1PU+A128KW", + "enc": "A256CBC-HS512", + "apu": "QWxpY2U", + "apv": "Qm9iIGFuZCBDaGFybGll", + "epk": { + "kty": "OKP", + "crv": "X25519", + "x": "k9of_cpAajy0poW5gaixXGs9nHkwg1AFqUAFa39dyBc", + }, + } + + assert rv_at_bob["header"]["unprotected"] == {"jku": "https://provider.test/jwks"} + + assert rv_at_bob["header"]["recipients"] == [ + {"header": {"kid": "bob-key-2"}}, + {"header": {"kid": "2021-05-06"}}, + ] + + assert rv_at_bob["payload"] == b"Three is a magic number." + + rv_at_charlie = jwe.deserialize_json(data, charlie_key, sender_key=alice_key) + + assert rv_at_charlie.keys() == {"header", "payload"} + + assert rv_at_charlie["header"].keys() == { + "protected", + "unprotected", + "recipients", + } + + assert rv_at_charlie["header"]["protected"] == { + "alg": "ECDH-1PU+A128KW", + "enc": "A256CBC-HS512", + "apu": "QWxpY2U", + "apv": "Qm9iIGFuZCBDaGFybGll", + "epk": { + "kty": "OKP", + "crv": "X25519", + "x": "k9of_cpAajy0poW5gaixXGs9nHkwg1AFqUAFa39dyBc", + }, + } + + assert rv_at_charlie["header"]["unprotected"] == { + "jku": "https://provider.test/jwks" + } + + assert rv_at_charlie["header"]["recipients"] == [ + {"header": {"kid": "bob-key-2"}}, + {"header": {"kid": "2021-05-06"}}, + ] + + assert rv_at_charlie["payload"] == b"Three is a magic number." + + +def test_parse_json(): + json_msg = """ + { + "protected": "eyJhbGciOiJFQ0RILTFQVStBMTI4S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwiYXB1IjoiUVd4cFkyVSIsImFwdiI6IlFtOWlJR0Z1WkNCRGFHRnliR2xsIiwiZXBrIjp7Imt0eSI6Ik9LUCIsImNydiI6IlgyNTUxOSIsIngiOiJrOW9mX2NwQWFqeTBwb1c1Z2FpeFhHczluSGt3ZzFBRnFVQUZhMzlkeUJjIn19", + "unprotected": { + "jku": "https://provider.test/jwks" + }, + "recipients": [ + { + "header": { + "kid": "bob-key-2" + }, + "encrypted_key": "pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQeU1cSl55cQ0hGezJu2N9IY0QN" + }, + { + "header": { + "kid": "2021-05-06" + }, + "encrypted_key": "56GVudgRLIMEElQ7DpXsijJVRSWUSDNdbWkdV3g0GUNq6hcT_GkxwnxlPIWrTXCqRpVKQC8fe4z3PQ2YH2afvjQ28aiCTWFE" + } + ], + "iv": "AAECAwQFBgcICQoLDA0ODw", + "ciphertext": "Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFYw", + "tag": "HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ" + }""" + + parsed_msg = JsonWebEncryption.parse_json(json_msg) + + assert parsed_msg == { + "protected": "eyJhbGciOiJFQ0RILTFQVStBMTI4S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwiYXB1IjoiUVd4cFkyVSIsImFwdiI6IlFtOWlJR0Z1WkNCRGFHRnliR2xsIiwiZXBrIjp7Imt0eSI6Ik9LUCIsImNydiI6IlgyNTUxOSIsIngiOiJrOW9mX2NwQWFqeTBwb1c1Z2FpeFhHczluSGt3ZzFBRnFVQUZhMzlkeUJjIn19", + "unprotected": {"jku": "https://provider.test/jwks"}, + "recipients": [ + { + "header": {"kid": "bob-key-2"}, + "encrypted_key": "pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQeU1cSl55cQ0hGezJu2N9IY0QN", + }, + { + "header": {"kid": "2021-05-06"}, + "encrypted_key": "56GVudgRLIMEElQ7DpXsijJVRSWUSDNdbWkdV3g0GUNq6hcT_GkxwnxlPIWrTXCqRpVKQC8fe4z3PQ2YH2afvjQ28aiCTWFE", + }, + ], + "iv": "AAECAwQFBgcICQoLDA0ODw", + "ciphertext": "Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFYw", + "tag": "HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ", + } + + +def test_parse_json_fails_if_json_msg_is_invalid(): + json_msg = """ + { + "protected": "eyJhbGciOiJFQ0RILTFQVStBMTI4S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwiYXB1IjoiUVd4cFkyVSIsImFwdiI6IlFtOWlJR0Z1WkNCRGFHRnliR2xsIiwiZXBrIjp7Imt0eSI6Ik9LUCIsImNydiI6IlgyNTUxOSIsIngiOiJrOW9mX2NwQWFqeTBwb1c1Z2FpeFhHczluSGt3ZzFBRnFVQUZhMzlkeUJjIn19", + "unprotected": { + "jku": "https://provider.test/jwks" + }, + "recipients": [ + { + "header": { + "kid": "bob-key-2" + , + "encrypted_key": "pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQeU1cSl55cQ0hGezJu2N9IY0QN" + }, + { + "header": { + "kid": "2021-05-06" + }, + "encrypted_key": "56GVudgRLIMEElQ7DpXsijJVRSWUSDNdbWkdV3g0GUNq6hcT_GkxwnxlPIWrTXCqRpVKQC8fe4z3PQ2YH2afvjQ28aiCTWFE" + } + ], + "iv": "AAECAwQFBgcICQoLDA0ODw", + "ciphertext": "Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFYw", + "tag": "HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ" + }""" + + with pytest.raises(DecodeError): + JsonWebEncryption.parse_json(json_msg) + + +def test_decryption_fails_if_ciphertext_is_invalid(): + jwe = JsonWebEncryption() + + alice_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4", + "d": "i9KuFhSzEBsiv3PKVL5115OCdsqQai5nj_Flzfkw5jU", + } + ) + bob_key = OKPKey.import_key( + { + "kty": "OKP", + "crv": "X25519", + "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw", + "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg", + } + ) + + data = { + "protected": "eyJhbGciOiJFQ0RILTFQVStBMTI4S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwiYXB1Ijoi" + + "UVd4cFkyVSIsImFwdiI6IlFtOWlJR0Z1WkNCRGFHRnliR2xsIiwiZXBrIjp7Imt0eSI6Ik9L" + + "UCIsImNydiI6IlgyNTUxOSIsIngiOiJrOW9mX2NwQWFqeTBwb1c1Z2FpeFhHczluSGt3ZzFB" + + "RnFVQUZhMzlkeUJjIn19", + "unprotected": {"jku": "https://provider.test/jwks"}, + "recipients": [ + { + "header": {"kid": "bob-key-2"}, + "encrypted_key": "pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQ" + + "eU1cSl55cQ0hGezJu2N9IY0QN", + } + ], + "iv": "AAECAwQFBgcICQoLDA0ODw", + "ciphertext": "Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFY", # invalid ciphertext + "tag": "HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ", + } + + with pytest.raises(InvalidTag): + jwe.deserialize_json(data, bob_key, sender_key=alice_key) + + +def test_generic_serialize_deserialize_for_compact_serialization(): + jwe = JsonWebEncryption() + + alice_key = OKPKey.generate_key("X25519", is_private=True) + bob_key = OKPKey.generate_key("X25519", is_private=True) + + header_obj = {"alg": "ECDH-1PU+A128KW", "enc": "A128CBC-HS256"} + + data = jwe.serialize(header_obj, b"hello", bob_key, sender_key=alice_key) + assert isinstance(data, bytes) + + rv = jwe.deserialize(data, bob_key, sender_key=alice_key) + assert rv["payload"] == b"hello" + + +def test_generic_serialize_deserialize_for_json_serialization(): + jwe = JsonWebEncryption() + + alice_key = OKPKey.generate_key("X25519", is_private=True) + bob_key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-1PU+A128KW", "enc": "A128CBC-HS256"} + header_obj = {"protected": protected} + + data = jwe.serialize(header_obj, b"hello", bob_key, sender_key=alice_key) + assert isinstance(data, dict) + + rv = jwe.deserialize(data, bob_key, sender_key=alice_key) + assert rv["payload"] == b"hello" + + +def test_generic_deserialize_for_json_serialization_string(): + jwe = JsonWebEncryption() + + alice_key = OKPKey.generate_key("X25519", is_private=True) + bob_key = OKPKey.generate_key("X25519", is_private=True) + + protected = {"alg": "ECDH-1PU+A128KW", "enc": "A128CBC-HS256"} + header_obj = {"protected": protected} + + data = jwe.serialize(header_obj, b"hello", bob_key, sender_key=alice_key) + assert isinstance(data, dict) + + data_as_string = json.dumps(data) + + rv = jwe.deserialize(data_as_string, bob_key, sender_key=alice_key) + assert rv["payload"] == b"hello" diff --git a/tests/jose/test_jwk.py b/tests/jose/test_jwk.py new file mode 100644 index 000000000..173d08c58 --- /dev/null +++ b/tests/jose/test_jwk.py @@ -0,0 +1,313 @@ +import pytest + +from authlib.common.encoding import base64_to_int +from authlib.common.encoding import json_dumps +from authlib.jose import ECKey +from authlib.jose import JsonWebKey +from authlib.jose import KeySet +from authlib.jose import OctKey +from authlib.jose import OKPKey +from authlib.jose import RSAKey +from tests.util import read_file_path + + +def test_oct_import_oct_key(): + # https://tools.ietf.org/html/rfc7520#section-3.5 + obj = { + "kty": "oct", + "kid": "018c0ae5-4d9b-471b-bfd6-eef314bc7037", + "use": "sig", + "alg": "HS256", + "k": "hJtXIZ2uSN5kbQfbtTNWbpdmhkV8FJG-Onbc6mxCcYg", + } + key = OctKey.import_key(obj) + new_obj = key.as_dict() + assert obj["k"] == new_obj["k"] + assert "use" in new_obj + + +def test_oct_invalid_oct_key(): + with pytest.raises(ValueError): + OctKey.import_key({}) + + +def test_oct_generate_oct_key(): + with pytest.raises(ValueError): + OctKey.generate_key(251) + + with pytest.raises(ValueError, match="oct key can not be generated as public"): + OctKey.generate_key(is_private=False) + + key = OctKey.generate_key() + assert "kid" in key.as_dict() + assert "use" not in key.as_dict() + + key2 = OctKey.import_key(key, {"use": "sig"}) + assert "use" in key2.as_dict() + + +def test_rsa_import_ssh_pem(): + raw = read_file_path("ssh_public.pem") + key = RSAKey.import_key(raw) + obj = key.as_dict() + assert obj["kty"] == "RSA" + + +def test_rsa_public_key(): + # https://tools.ietf.org/html/rfc7520#section-3.3 + obj = read_file_path("jwk_public.json") + key = RSAKey.import_key(obj) + new_obj = key.as_dict() + assert base64_to_int(new_obj["n"]) == base64_to_int(obj["n"]) + assert base64_to_int(new_obj["e"]) == base64_to_int(obj["e"]) + + +def test_rsa_private_key(): + # https://tools.ietf.org/html/rfc7520#section-3.4 + obj = read_file_path("jwk_private.json") + key = RSAKey.import_key(obj) + new_obj = key.as_dict(is_private=True) + assert base64_to_int(new_obj["n"]) == base64_to_int(obj["n"]) + assert base64_to_int(new_obj["e"]) == base64_to_int(obj["e"]) + assert base64_to_int(new_obj["d"]) == base64_to_int(obj["d"]) + assert base64_to_int(new_obj["p"]) == base64_to_int(obj["p"]) + assert base64_to_int(new_obj["q"]) == base64_to_int(obj["q"]) + assert base64_to_int(new_obj["dp"]) == base64_to_int(obj["dp"]) + assert base64_to_int(new_obj["dq"]) == base64_to_int(obj["dq"]) + assert base64_to_int(new_obj["qi"]) == base64_to_int(obj["qi"]) + + +def test_rsa_private_key2(): + rsa_obj = read_file_path("jwk_private.json") + obj = { + "kty": "RSA", + "kid": "bilbo.baggins@hobbiton.example", + "use": "sig", + "n": rsa_obj["n"], + "d": rsa_obj["d"], + "e": "AQAB", + } + key = RSAKey.import_key(obj) + new_obj = key.as_dict(is_private=True) + assert base64_to_int(new_obj["n"]) == base64_to_int(obj["n"]) + assert base64_to_int(new_obj["e"]) == base64_to_int(obj["e"]) + assert base64_to_int(new_obj["d"]) == base64_to_int(obj["d"]) + assert base64_to_int(new_obj["p"]) == base64_to_int(rsa_obj["p"]) + assert base64_to_int(new_obj["q"]) == base64_to_int(rsa_obj["q"]) + assert base64_to_int(new_obj["dp"]) == base64_to_int(rsa_obj["dp"]) + assert base64_to_int(new_obj["dq"]) == base64_to_int(rsa_obj["dq"]) + assert base64_to_int(new_obj["qi"]) == base64_to_int(rsa_obj["qi"]) + + +def test_invalid_rsa(): + with pytest.raises(ValueError): + RSAKey.import_key({"kty": "RSA"}) + rsa_obj = read_file_path("jwk_private.json") + obj = { + "kty": "RSA", + "kid": "bilbo.baggins@hobbiton.example", + "use": "sig", + "n": rsa_obj["n"], + "d": rsa_obj["d"], + "p": rsa_obj["p"], + "e": "AQAB", + } + with pytest.raises(ValueError): + RSAKey.import_key(obj) + + +def test_rsa_key_generate(): + with pytest.raises(ValueError): + RSAKey.generate_key(256) + with pytest.raises(ValueError): + RSAKey.generate_key(2001) + + key1 = RSAKey.generate_key(is_private=True) + assert b"PRIVATE" in key1.as_pem(is_private=True) + assert b"PUBLIC" in key1.as_pem(is_private=False) + + key2 = RSAKey.generate_key(is_private=False) + with pytest.raises(ValueError): + key2.as_pem(True) + assert b"PUBLIC" in key2.as_pem(is_private=False) + + +def test_ec_public_key(): + # https://tools.ietf.org/html/rfc7520#section-3.1 + obj = read_file_path("secp521r1-public.json") + key = ECKey.import_key(obj) + new_obj = key.as_dict() + assert new_obj["crv"] == obj["crv"] + assert base64_to_int(new_obj["x"]) == base64_to_int(obj["x"]) + assert base64_to_int(new_obj["y"]) == base64_to_int(obj["y"]) + assert key.as_json()[0] == "{" + + +def test_ec_private_key(): + # https://tools.ietf.org/html/rfc7520#section-3.2 + obj = read_file_path("secp521r1-private.json") + key = ECKey.import_key(obj) + new_obj = key.as_dict(is_private=True) + assert new_obj["crv"] == obj["crv"] + assert base64_to_int(new_obj["x"]) == base64_to_int(obj["x"]) + assert base64_to_int(new_obj["y"]) == base64_to_int(obj["y"]) + assert base64_to_int(new_obj["d"]) == base64_to_int(obj["d"]) + + +def test_invalid_ec(): + with pytest.raises(ValueError): + ECKey.import_key({"kty": "EC"}) + + +def test_ec_key_generate(): + with pytest.raises(ValueError): + ECKey.generate_key("Invalid") + + key1 = ECKey.generate_key("P-384", is_private=True) + assert b"PRIVATE" in key1.as_pem(is_private=True) + assert b"PUBLIC" in key1.as_pem(is_private=False) + + key2 = ECKey.generate_key("P-256", is_private=False) + with pytest.raises(ValueError): + key2.as_pem(True) + assert b"PUBLIC" in key2.as_pem(is_private=False) + + +def test_import_okp_ssh_key(): + raw = read_file_path("ed25519-ssh.pub") + key = OKPKey.import_key(raw) + obj = key.as_dict() + assert obj["kty"] == "OKP" + assert obj["crv"] == "Ed25519" + + +def test_import_okp_public_key(): + obj = { + "x": "AD9E0JYnpV-OxZbd8aN1t4z71Vtf6JcJC7TYHT0HDbg", + "crv": "Ed25519", + "kty": "OKP", + } + key = OKPKey.import_key(obj) + new_obj = key.as_dict() + assert obj["x"] == new_obj["x"] + + +def test_import_okp_private_pem(): + raw = read_file_path("ed25519-pkcs8.pem") + key = OKPKey.import_key(raw) + obj = key.as_dict(is_private=True) + assert obj["kty"] == "OKP" + assert obj["crv"] == "Ed25519" + assert "d" in obj + + +def test_import_okp_private_dict(): + obj = { + "x": "11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo", + "d": "nWGxne_9WmC6hEr0kuwsxERJxWl7MmkZcDusAxyuf2A", + "crv": "Ed25519", + "kty": "OKP", + } + key = OKPKey.import_key(obj) + new_obj = key.as_dict(is_private=True) + assert obj["d"] == new_obj["d"] + + +def test_okp_key_generate_pem(): + with pytest.raises(ValueError): + OKPKey.generate_key("invalid") + + key1 = OKPKey.generate_key("Ed25519", is_private=True) + assert b"PRIVATE" in key1.as_pem(is_private=True) + assert b"PUBLIC" in key1.as_pem(is_private=False) + + key2 = OKPKey.generate_key("X25519", is_private=False) + with pytest.raises(ValueError): + key2.as_pem(True) + assert b"PUBLIC" in key2.as_pem(is_private=False) + + +def test_jwk_generate_keys(): + key = JsonWebKey.generate_key(kty="oct", crv_or_size=256, is_private=True) + assert key["kty"] == "oct" + + key = JsonWebKey.generate_key(kty="EC", crv_or_size="P-256") + assert key["kty"] == "EC" + + key = JsonWebKey.generate_key(kty="RSA", crv_or_size=2048) + assert key["kty"] == "RSA" + + key = JsonWebKey.generate_key(kty="OKP", crv_or_size="Ed25519") + assert key["kty"] == "OKP" + + +def test_jwk_import_keys(): + rsa_pub_pem = read_file_path("rsa_public.pem") + with pytest.raises(ValueError): + JsonWebKey.import_key(rsa_pub_pem, {"kty": "EC"}) + + key = JsonWebKey.import_key(raw=rsa_pub_pem, options={"kty": "RSA"}) + assert "e" in dict(key) + assert "n" in dict(key) + + key = JsonWebKey.import_key(raw=rsa_pub_pem) + assert "e" in dict(key) + assert "n" in dict(key) + + +def test_jwk_import_key_set(): + jwks_public = read_file_path("jwks_public.json") + key_set1 = JsonWebKey.import_key_set(jwks_public) + key1 = key_set1.find_by_kid("abc") + assert key1["e"] == "AQAB" + + key_set2 = JsonWebKey.import_key_set(jwks_public["keys"]) + key2 = key_set2.find_by_kid("abc") + assert key2["e"] == "AQAB" + + key_set3 = JsonWebKey.import_key_set(json_dumps(jwks_public)) + key3 = key_set3.find_by_kid("abc") + assert key3["e"] == "AQAB" + + with pytest.raises(ValueError): + JsonWebKey.import_key_set("invalid") + + +def test_jwk_find_by_kid_with_use(): + key1 = OctKey.import_key("secret", {"kid": "abc", "use": "sig"}) + key2 = OctKey.import_key("secret", {"kid": "abc", "use": "enc"}) + + key_set = KeySet([key1, key2]) + key = key_set.find_by_kid("abc", use="sig") + assert key == key1 + + key = key_set.find_by_kid("abc", use="enc") + assert key == key2 + + +def test_jwk_find_by_kid_with_alg(): + key1 = OctKey.import_key("secret", {"kid": "abc", "alg": "HS256"}) + key2 = OctKey.import_key("secret", {"kid": "abc", "alg": "dir"}) + + key_set = KeySet([key1, key2]) + key = key_set.find_by_kid("abc", alg="HS256") + assert key == key1 + + key = key_set.find_by_kid("abc", alg="dir") + assert key == key2 + + +def test_jwk_thumbprint(): + # https://tools.ietf.org/html/rfc7638#section-3.1 + data = read_file_path("thumbprint_example.json") + key = JsonWebKey.import_key(data) + expected = "NzbLsXh8uDCcd-6MNwXF4W_7noWXFZAfHkxZsRGC9Xs" + assert key.thumbprint() == expected + + +def test_jwk_key_set(): + key = RSAKey.generate_key(is_private=True) + key_set = KeySet([key]) + obj = key_set.as_dict()["keys"][0] + assert "kid" in obj + assert key_set.as_json()[0] == "{" diff --git a/tests/jose/test_jws.py b/tests/jose/test_jws.py new file mode 100644 index 000000000..8eae4b5c2 --- /dev/null +++ b/tests/jose/test_jws.py @@ -0,0 +1,316 @@ +import json + +import pytest + +from authlib.jose import JsonWebSignature +from authlib.jose import errors +from tests.util import read_file_path + + +def test_invalid_input(): + jws = JsonWebSignature() + with pytest.raises(errors.DecodeError): + jws.deserialize("a", "k") + with pytest.raises(errors.DecodeError): + jws.deserialize("a.b.c", "k") + with pytest.raises(errors.DecodeError): + jws.deserialize("YQ.YQ.YQ", "k") # a + with pytest.raises(errors.DecodeError): + jws.deserialize("W10.a.YQ", "k") # [] + with pytest.raises(errors.DecodeError): + jws.deserialize("e30.a.YQ", "k") # {} + with pytest.raises(errors.DecodeError): + jws.deserialize("eyJhbGciOiJzIn0.a.YQ", "k") + with pytest.raises(errors.DecodeError): + jws.deserialize("eyJhbGciOiJzIn0.YQ.a", "k") + + +def test_invalid_alg(): + jws = JsonWebSignature() + with pytest.raises(errors.UnsupportedAlgorithmError): + jws.deserialize( + "eyJhbGciOiJzIn0.YQ.YQ", + "k", + ) + with pytest.raises(errors.MissingAlgorithmError): + jws.serialize({}, "", "k") + with pytest.raises(errors.UnsupportedAlgorithmError): + jws.serialize({"alg": "s"}, "", "k") + + +def test_bad_signature(): + jws = JsonWebSignature() + s = "eyJhbGciOiJIUzI1NiJ9.YQ.YQ" + with pytest.raises(errors.BadSignatureError): + jws.deserialize(s, "k") + + +def test_not_supported_alg(): + jws = JsonWebSignature(algorithms=["HS256"]) + s = jws.serialize({"alg": "HS256"}, "hello", "secret") + + jws = JsonWebSignature(algorithms=["RS256"]) + with pytest.raises(errors.UnsupportedAlgorithmError): + jws.serialize({"alg": "HS256"}, "hello", "secret") + + with pytest.raises(errors.UnsupportedAlgorithmError): + jws.deserialize(s, "secret") + + +def test_compact_jws(): + jws = JsonWebSignature(algorithms=["HS256"]) + s = jws.serialize({"alg": "HS256"}, "hello", "secret") + data = jws.deserialize(s, "secret") + header, payload = data["header"], data["payload"] + assert payload == b"hello" + assert header["alg"] == "HS256" + assert "signature" not in data + + +def test_compact_rsa(): + jws = JsonWebSignature() + private_key = read_file_path("rsa_private.pem") + public_key = read_file_path("rsa_public.pem") + s = jws.serialize({"alg": "RS256"}, "hello", private_key) + data = jws.deserialize(s, public_key) + header, payload = data["header"], data["payload"] + assert payload == b"hello" + assert header["alg"] == "RS256" + + # can deserialize with private key + data2 = jws.deserialize(s, private_key) + assert data == data2 + + ssh_pub_key = read_file_path("ssh_public.pem") + with pytest.raises(errors.BadSignatureError): + jws.deserialize(s, ssh_pub_key) + + +def test_compact_rsa_pss(): + jws = JsonWebSignature() + private_key = read_file_path("rsa_private.pem") + public_key = read_file_path("rsa_public.pem") + s = jws.serialize({"alg": "PS256"}, "hello", private_key) + data = jws.deserialize(s, public_key) + header, payload = data["header"], data["payload"] + assert payload == b"hello" + assert header["alg"] == "PS256" + ssh_pub_key = read_file_path("ssh_public.pem") + with pytest.raises(errors.BadSignatureError): + jws.deserialize(s, ssh_pub_key) + + +def test_compact_none(): + jws = JsonWebSignature(algorithms=["none"]) + s = jws.serialize({"alg": "none"}, "hello", None) + data = jws.deserialize(s, None) + header, payload = data["header"], data["payload"] + assert payload == b"hello" + assert header["alg"] == "none" + + +def test_flattened_json_jws(): + jws = JsonWebSignature() + protected = {"alg": "HS256"} + header = {"protected": protected, "header": {"kid": "a"}} + s = jws.serialize(header, "hello", "secret") + assert isinstance(s, dict) + + data = jws.deserialize(s, "secret") + header, payload = data["header"], data["payload"] + assert payload == b"hello" + assert header["alg"] == "HS256" + assert "protected" not in data + + +def test_nested_json_jws(): + jws = JsonWebSignature() + protected = {"alg": "HS256"} + header = {"protected": protected, "header": {"kid": "a"}} + s = jws.serialize([header], "hello", "secret") + assert isinstance(s, dict) + assert "signatures" in s + + data = jws.deserialize(s, "secret") + header, payload = data["header"], data["payload"] + assert payload == b"hello" + assert header[0]["alg"] == "HS256" + assert "signatures" not in data + + # test bad signature + with pytest.raises(errors.BadSignatureError): + jws.deserialize(s, "f") + + +def test_function_key(): + protected = {"alg": "HS256"} + header = [ + {"protected": protected, "header": {"kid": "a"}}, + {"protected": protected, "header": {"kid": "b"}}, + ] + + def load_key(header, payload): + assert payload == b"hello" + kid = header.get("kid") + if kid == "a": + return "secret-a" + return "secret-b" + + jws = JsonWebSignature() + s = jws.serialize(header, b"hello", load_key) + assert isinstance(s, dict) + assert "signatures" in s + + data = jws.deserialize(json.dumps(s), load_key) + header, payload = data["header"], data["payload"] + assert payload == b"hello" + assert header[0]["alg"] == "HS256" + assert "signature" not in data + + +def test_serialize_json_empty_payload(): + jws = JsonWebSignature() + protected = {"alg": "HS256"} + header = {"protected": protected, "header": {"kid": "a"}} + s = jws.serialize_json(header, b"", "secret") + data = jws.deserialize_json(s, "secret") + assert data["payload"] == b"" + + +def test_fail_deserialize_json(): + jws = JsonWebSignature() + with pytest.raises(errors.DecodeError): + jws.deserialize_json(None, "") + with pytest.raises(errors.DecodeError): + jws.deserialize_json("[]", "") + with pytest.raises(errors.DecodeError): + jws.deserialize_json("{}", "") + + # missing protected + s = json.dumps({"payload": "YQ"}) + with pytest.raises(errors.DecodeError): + jws.deserialize_json(s, "") + + # missing signature + s = json.dumps({"payload": "YQ", "protected": "YQ"}) + with pytest.raises(errors.DecodeError): + jws.deserialize_json(s, "") + + +def test_serialize_json_overwrite_header(): + jws = JsonWebSignature() + protected = {"alg": "HS256", "kid": "a"} + header = {"protected": protected} + result = jws.serialize_json(header, b"", "secret") + result["header"] = {"kid": "b"} + decoded = jws.deserialize_json(result, "secret") + assert decoded["header"]["kid"] == "a" + + +def test_validate_header(): + jws = JsonWebSignature(private_headers=[]) + protected = {"alg": "HS256", "invalid": "k"} + header = {"protected": protected, "header": {"kid": "a"}} + with pytest.raises(errors.InvalidHeaderParameterNameError): + jws.serialize( + header, + b"hello", + "secret", + ) + jws = JsonWebSignature(private_headers=["invalid"]) + s = jws.serialize(header, b"hello", "secret") + assert isinstance(s, dict) + + jws = JsonWebSignature() + s = jws.serialize(header, b"hello", "secret") + assert isinstance(s, dict) + + +def test_validate_crit_header_with_serialize(): + jws = JsonWebSignature() + protected = {"alg": "HS256", "kid": "1", "crit": ["kid"]} + jws.serialize(protected, b"hello", "secret") + + protected = {"alg": "HS256", "crit": ["kid"]} + with pytest.raises(errors.InvalidCritHeaderParameterNameError): + jws.serialize(protected, b"hello", "secret") + + protected = {"alg": "HS256", "invalid": "1", "crit": ["invalid"]} + with pytest.raises(errors.InvalidCritHeaderParameterNameError): + jws.serialize(protected, b"hello", "secret") + + +def test_validate_crit_header_with_deserialize(): + jws = JsonWebSignature() + case1 = "eyJhbGciOiJIUzI1NiIsImNyaXQiOlsia2lkIl19.aGVsbG8.RVimhJH2LRGAeHy0ZcbR9xsgKhzhxIBkHs7S_TDgWvc" + with pytest.raises(errors.InvalidCritHeaderParameterNameError): + jws.deserialize(case1, "secret") + + case2 = ( + "eyJhbGciOiJIUzI1NiIsImludmFsaWQiOiIxIiwiY3JpdCI6WyJpbnZhbGlkIl19." + "aGVsbG8.ifW_D1AQWzggrpd8npcnmpiwMD9dp5FTX66lCkYFENM" + ) + with pytest.raises(errors.InvalidCritHeaderParameterNameError): + jws.deserialize(case2, "secret") + + +def test_unprotected_crit_rejected_in_json_serialize(): + jws = JsonWebSignature() + protected = {"alg": "HS256", "kid": "a"} + # Place 'crit' in unprotected header; must be rejected + header = {"protected": protected, "header": {"kid": "a", "crit": ["kid"]}} + with pytest.raises(errors.InvalidHeaderParameterNameError): + jws.serialize_json(header, b"hello", "secret") + + +def test_unprotected_crit_rejected_in_json_deserialize(): + jws = JsonWebSignature() + protected = {"alg": "HS256", "kid": "a"} + header = {"protected": protected, "header": {"kid": "a"}} + data = jws.serialize_json(header, b"hello", "secret") + # Tamper by adding 'crit' into the unprotected header; must be rejected + data_tampered = dict(data) + data_tampered["header"] = {"kid": "a", "crit": ["kid"]} + with pytest.raises(errors.InvalidHeaderParameterNameError): + jws.deserialize_json(data_tampered, "secret") + + +def test_ES512_alg(): + jws = JsonWebSignature() + private_key = read_file_path("secp521r1-private.json") + public_key = read_file_path("secp521r1-public.json") + with pytest.raises(ValueError): + jws.serialize({"alg": "ES256"}, "hello", private_key) + s = jws.serialize({"alg": "ES512"}, "hello", private_key) + data = jws.deserialize(s, public_key) + header, payload = data["header"], data["payload"] + assert payload == b"hello" + assert header["alg"] == "ES512" + + +def test_ES256K_alg(): + jws = JsonWebSignature(algorithms=["ES256K"]) + private_key = read_file_path("secp256k1-private.pem") + public_key = read_file_path("secp256k1-pub.pem") + s = jws.serialize({"alg": "ES256K"}, "hello", private_key) + data = jws.deserialize(s, public_key) + header, payload = data["header"], data["payload"] + assert payload == b"hello" + assert header["alg"] == "ES256K" + + +def test_deserialize_exceeds_length(): + jws = JsonWebSignature() + value = "aa" * 256000 + + # header exceeds length + with pytest.raises(ValueError): + jws.deserialize(value + "." + value + "." + value, "") + + # payload exceeds length + with pytest.raises(ValueError): + jws.deserialize("eyJhbGciOiJIUzI1NiJ9." + value + "." + value, "") + + # signature exceeds length + with pytest.raises(ValueError): + jws.deserialize("eyJhbGciOiJIUzI1NiJ9.YQ." + value, "") diff --git a/tests/jose/test_jwt.py b/tests/jose/test_jwt.py new file mode 100644 index 000000000..c8da110e7 --- /dev/null +++ b/tests/jose/test_jwt.py @@ -0,0 +1,265 @@ +import datetime + +import pytest + +from authlib.jose import JsonWebKey +from authlib.jose import JsonWebToken +from authlib.jose import JWTClaims +from authlib.jose import errors +from authlib.jose import jwt +from authlib.jose.errors import UnsupportedAlgorithmError +from tests.util import read_file_path + + +def test_init_algorithms(): + _jwt = JsonWebToken(["RS256"]) + with pytest.raises(UnsupportedAlgorithmError): + _jwt.encode({"alg": "HS256"}, {}, "k") + + _jwt = JsonWebToken("RS256") + with pytest.raises(UnsupportedAlgorithmError): + _jwt.encode({"alg": "HS256"}, {}, "k") + + +def test_encode_sensitive_data(): + # check=False won't raise error + jwt.encode({"alg": "HS256"}, {"password": ""}, "k", check=False) + with pytest.raises(errors.InsecureClaimError): + jwt.encode( + {"alg": "HS256"}, + {"password": ""}, + "k", + ) + with pytest.raises(errors.InsecureClaimError): + jwt.encode( + {"alg": "HS256"}, + {"text": "4242424242424242"}, + "k", + ) + + +def test_encode_datetime(): + now = datetime.datetime.now(tz=datetime.timezone.utc) + id_token = jwt.encode({"alg": "HS256"}, {"exp": now}, "k") + claims = jwt.decode(id_token, "k") + assert isinstance(claims.exp, int) + + +def test_validate_essential_claims(): + id_token = jwt.encode({"alg": "HS256"}, {"iss": "foo"}, "k") + claims_options = {"iss": {"essential": True, "values": ["foo"]}} + claims = jwt.decode(id_token, "k", claims_options=claims_options) + claims.validate() + + claims.options = {"sub": {"essential": True}} + with pytest.raises(errors.MissingClaimError): + claims.validate() + + +def test_attribute_error(): + claims = JWTClaims({"iss": "foo"}, {"alg": "HS256"}) + with pytest.raises(AttributeError): + claims.invalid # noqa: B018 + + +def test_invalid_values(): + id_token = jwt.encode({"alg": "HS256"}, {"iss": "foo"}, "k") + claims_options = {"iss": {"values": ["bar"]}} + claims = jwt.decode(id_token, "k", claims_options=claims_options) + with pytest.raises(errors.InvalidClaimError): + claims.validate() + claims.options = {"iss": {"value": "bar"}} + with pytest.raises(errors.InvalidClaimError): + claims.validate() + + +def test_validate_expected_issuer_received_None(): + id_token = jwt.encode({"alg": "HS256"}, {"iss": None, "sub": None}, "k") + claims_options = {"iss": {"essential": True, "values": ["foo"]}} + claims = jwt.decode(id_token, "k", claims_options=claims_options) + with pytest.raises(errors.InvalidClaimError): + claims.validate() + + +def test_validate_aud(): + id_token = jwt.encode({"alg": "HS256"}, {"aud": "foo"}, "k") + claims_options = {"aud": {"essential": True, "value": "foo"}} + claims = jwt.decode(id_token, "k", claims_options=claims_options) + claims.validate() + + claims.options = {"aud": {"values": ["bar"]}} + with pytest.raises(errors.InvalidClaimError): + claims.validate() + + id_token = jwt.encode({"alg": "HS256"}, {"aud": ["foo", "bar"]}, "k") + claims = jwt.decode(id_token, "k", claims_options=claims_options) + claims.validate() + # no validate + claims.options = {"aud": {"values": []}} + claims.validate() + + +def test_validate_exp(): + id_token = jwt.encode({"alg": "HS256"}, {"exp": "invalid"}, "k") + claims = jwt.decode(id_token, "k") + with pytest.raises(errors.InvalidClaimError): + claims.validate() + + id_token = jwt.encode({"alg": "HS256"}, {"exp": 1234}, "k") + claims = jwt.decode(id_token, "k") + with pytest.raises(errors.ExpiredTokenError): + claims.validate() + + +def test_validate_nbf(): + id_token = jwt.encode({"alg": "HS256"}, {"nbf": "invalid"}, "k") + claims = jwt.decode(id_token, "k") + with pytest.raises(errors.InvalidClaimError): + claims.validate() + + id_token = jwt.encode({"alg": "HS256"}, {"nbf": 1234}, "k") + claims = jwt.decode(id_token, "k") + claims.validate() + + id_token = jwt.encode({"alg": "HS256"}, {"nbf": 1234}, "k") + claims = jwt.decode(id_token, "k") + with pytest.raises(errors.InvalidTokenError): + claims.validate(123) + + +def test_validate_iat_issued_in_future(): + in_future = datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta( + seconds=10 + ) + id_token = jwt.encode({"alg": "HS256"}, {"iat": in_future}, "k") + claims = jwt.decode(id_token, "k") + with pytest.raises( + errors.InvalidTokenError, + match="The token is not valid as it was issued in the future", + ): + claims.validate() + + +def test_validate_iat_issued_in_future_with_insufficient_leeway(): + in_future = datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta( + seconds=10 + ) + id_token = jwt.encode({"alg": "HS256"}, {"iat": in_future}, "k") + claims = jwt.decode(id_token, "k") + with pytest.raises( + errors.InvalidTokenError, + match="The token is not valid as it was issued in the future", + ): + claims.validate(leeway=5) + + +def test_validate_iat_issued_in_future_with_sufficient_leeway(): + in_future = datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta( + seconds=10 + ) + id_token = jwt.encode({"alg": "HS256"}, {"iat": in_future}, "k") + claims = jwt.decode(id_token, "k") + claims.validate(leeway=20) + + +def test_validate_iat_issued_in_past(): + in_future = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta( + seconds=10 + ) + id_token = jwt.encode({"alg": "HS256"}, {"iat": in_future}, "k") + claims = jwt.decode(id_token, "k") + claims.validate() + + +def test_validate_iat(): + id_token = jwt.encode({"alg": "HS256"}, {"iat": "invalid"}, "k") + claims = jwt.decode(id_token, "k") + with pytest.raises(errors.InvalidClaimError): + claims.validate() + + +def test_validate_jti(): + id_token = jwt.encode({"alg": "HS256"}, {"jti": "bar"}, "k") + claims_options = {"jti": {"validate": lambda c, o: o == "foo"}} + claims = jwt.decode(id_token, "k", claims_options=claims_options) + with pytest.raises(errors.InvalidClaimError): + claims.validate() + + +def test_validate_custom(): + id_token = jwt.encode({"alg": "HS256"}, {"custom": "foo"}, "k") + claims_options = {"custom": {"validate": lambda c, o: o == "bar"}} + claims = jwt.decode(id_token, "k", claims_options=claims_options) + with pytest.raises(errors.InvalidClaimError): + claims.validate() + + +def test_use_jws(): + payload = {"name": "hi"} + private_key = read_file_path("rsa_private.pem") + pub_key = read_file_path("rsa_public.pem") + data = jwt.encode({"alg": "RS256"}, payload, private_key) + assert data.count(b".") == 2 + + claims = jwt.decode(data, pub_key) + assert claims["name"] == "hi" + + +def test_use_jwe(): + payload = {"name": "hi"} + private_key = read_file_path("rsa_private.pem") + pub_key = read_file_path("rsa_public.pem") + _jwt = JsonWebToken(["RSA-OAEP", "A256GCM"]) + data = _jwt.encode({"alg": "RSA-OAEP", "enc": "A256GCM"}, payload, pub_key) + assert data.count(b".") == 4 + + claims = _jwt.decode(data, private_key) + assert claims["name"] == "hi" + + +def test_use_jwks(): + header = {"alg": "RS256", "kid": "abc"} + payload = {"name": "hi"} + private_key = read_file_path("jwks_private.json") + pub_key = read_file_path("jwks_public.json") + data = jwt.encode(header, payload, private_key) + assert data.count(b".") == 2 + claims = jwt.decode(data, pub_key) + assert claims["name"] == "hi" + + +def test_use_jwks_single_kid(): + """Test that jwks can be decoded if a kid for decoding is given and encoded data has no kid and only one key is set.""" + header = {"alg": "RS256"} + payload = {"name": "hi"} + private_key = read_file_path("jwks_single_private.json") + pub_key = read_file_path("jwks_single_public.json") + data = jwt.encode(header, payload, private_key) + assert data.count(b".") == 2 + claims = jwt.decode(data, pub_key) + assert claims["name"] == "hi" + + +# Added a unit test to showcase my problem. +# This calls jwt.decode similarly as is done in parse_id_token method of the AsyncOpenIDMixin class when the id token does not contain a kid in the alg header. +def test_use_jwks_single_kid_keyset(): + """Test that jwks can be decoded if a kid for decoding is given and encoded data has no kid and a keyset with one key.""" + header = {"alg": "RS256"} + payload = {"name": "hi"} + private_key = read_file_path("jwks_single_private.json") + pub_key = read_file_path("jwks_single_public.json") + data = jwt.encode(header, payload, private_key) + assert data.count(b".") == 2 + claims = jwt.decode(data, JsonWebKey.import_key_set(pub_key)) + assert claims["name"] == "hi" + + +def test_with_ec(): + payload = {"name": "hi"} + private_key = read_file_path("secp521r1-private.json") + pub_key = read_file_path("secp521r1-public.json") + data = jwt.encode({"alg": "ES512"}, payload, private_key) + assert data.count(b".") == 2 + + claims = jwt.decode(data, pub_key) + assert claims["name"] == "hi" diff --git a/tests/jose/test_rfc8037.py b/tests/jose/test_rfc8037.py new file mode 100644 index 000000000..47d69926e --- /dev/null +++ b/tests/jose/test_rfc8037.py @@ -0,0 +1,13 @@ +from authlib.jose import JsonWebSignature +from tests.util import read_file_path + + +def test_EdDSA_alg(): + jws = JsonWebSignature(algorithms=["EdDSA"]) + private_key = read_file_path("ed25519-pkcs8.pem") + public_key = read_file_path("ed25519-pub.pem") + s = jws.serialize({"alg": "EdDSA"}, "hello", private_key) + data = jws.deserialize(s, public_key) + header, payload = data["header"], data["payload"] + assert payload == b"hello" + assert header["alg"] == "EdDSA" diff --git a/tests/py3/test_httpx_client/test_assertion_client.py b/tests/py3/test_httpx_client/test_assertion_client.py deleted file mode 100644 index 91b05297d..000000000 --- a/tests/py3/test_httpx_client/test_assertion_client.py +++ /dev/null @@ -1,65 +0,0 @@ -import time -import pytest -from authlib.integrations.httpx_client import AssertionClient -from tests.py3.utils import MockDispatch - - -default_token = { - 'token_type': 'Bearer', - 'access_token': 'a', - 'refresh_token': 'b', - 'expires_in': '3600', - 'expires_at': int(time.time()) + 3600, -} - - -@pytest.mark.asyncio -def test_refresh_token(): - def verifier(request): - content = request.form - if str(request.url) == 'https://i.b/token': - assert 'assertion' in content - - with AssertionClient( - 'https://i.b/token', - grant_type=AssertionClient.JWT_BEARER_GRANT_TYPE, - issuer='foo', - subject='foo', - audience='foo', - alg='HS256', - key='secret', - app=MockDispatch(default_token, assert_func=verifier) - ) as client: - client.get('https://i.b') - - # trigger more case - now = int(time.time()) - with AssertionClient( - 'https://i.b/token', - issuer='foo', - subject=None, - audience='foo', - issued_at=now, - expires_at=now + 3600, - header={'alg': 'HS256'}, - key='secret', - scope='email', - claims={'test_mode': 'true'}, - app=MockDispatch(default_token, assert_func=verifier) - ) as client: - client.get('https://i.b') - client.get('https://i.b') - - -@pytest.mark.asyncio -def test_without_alg(): - with AssertionClient( - 'https://i.b/token', - issuer='foo', - subject='foo', - audience='foo', - key='secret', - app=MockDispatch(default_token) - ) as client: - with pytest.raises(ValueError): - client.get('https://i.b') diff --git a/tests/py3/test_httpx_client/test_async_assertion_client.py b/tests/py3/test_httpx_client/test_async_assertion_client.py deleted file mode 100644 index 46286bff3..000000000 --- a/tests/py3/test_httpx_client/test_async_assertion_client.py +++ /dev/null @@ -1,65 +0,0 @@ -import time -import pytest -from authlib.integrations.httpx_client import AsyncAssertionClient -from tests.py3.utils import AsyncMockDispatch - - -default_token = { - 'token_type': 'Bearer', - 'access_token': 'a', - 'refresh_token': 'b', - 'expires_in': '3600', - 'expires_at': int(time.time()) + 3600, -} - - -@pytest.mark.asyncio -async def test_refresh_token(): - async def verifier(request): - content = await request.body() - if str(request.url) == 'https://i.b/token': - assert b'assertion=' in content - - async with AsyncAssertionClient( - 'https://i.b/token', - grant_type=AsyncAssertionClient.JWT_BEARER_GRANT_TYPE, - issuer='foo', - subject='foo', - audience='foo', - alg='HS256', - key='secret', - app=AsyncMockDispatch(default_token, assert_func=verifier) - ) as client: - await client.get('https://i.b') - - # trigger more case - now = int(time.time()) - async with AsyncAssertionClient( - 'https://i.b/token', - issuer='foo', - subject=None, - audience='foo', - issued_at=now, - expires_at=now + 3600, - header={'alg': 'HS256'}, - key='secret', - scope='email', - claims={'test_mode': 'true'}, - app=AsyncMockDispatch(default_token, assert_func=verifier) - ) as client: - await client.get('https://i.b') - await client.get('https://i.b') - - -@pytest.mark.asyncio -async def test_without_alg(): - async with AsyncAssertionClient( - 'https://i.b/token', - issuer='foo', - subject='foo', - audience='foo', - key='secret', - app=AsyncMockDispatch() - ) as client: - with pytest.raises(ValueError): - await client.get('https://i.b') diff --git a/tests/py3/test_httpx_client/test_async_oauth1_client.py b/tests/py3/test_httpx_client/test_async_oauth1_client.py deleted file mode 100644 index 757035678..000000000 --- a/tests/py3/test_httpx_client/test_async_oauth1_client.py +++ /dev/null @@ -1,157 +0,0 @@ -import pytest -from authlib.integrations.httpx_client import ( - OAuthError, - AsyncOAuth1Client, - SIGNATURE_TYPE_BODY, - SIGNATURE_TYPE_QUERY, -) -from tests.py3.utils import AsyncMockDispatch - -oauth_url = 'https://example.com/oauth' - - -@pytest.mark.asyncio -async def test_fetch_request_token_via_header(): - request_token = {'oauth_token': '1', 'oauth_token_secret': '2'} - - async def assert_func(request): - auth_header = request.headers.get('authorization') - assert 'oauth_consumer_key="id"' in auth_header - assert 'oauth_signature=' in auth_header - - app = AsyncMockDispatch(request_token, assert_func=assert_func) - async with AsyncOAuth1Client('id', 'secret', app=app) as client: - response = await client.fetch_request_token(oauth_url) - - assert response == request_token - - -@pytest.mark.asyncio -async def test_fetch_request_token_via_body(): - request_token = {'oauth_token': '1', 'oauth_token_secret': '2'} - - async def assert_func(request): - auth_header = request.headers.get('authorization') - assert auth_header is None - - content = await request.body() - assert b'oauth_consumer_key=id' in content - assert b'&oauth_signature=' in content - - mock_response = AsyncMockDispatch(request_token, assert_func=assert_func) - - async with AsyncOAuth1Client( - 'id', 'secret', signature_type=SIGNATURE_TYPE_BODY, - app=mock_response, - ) as client: - response = await client.fetch_request_token(oauth_url) - - assert response == request_token - - -@pytest.mark.asyncio -async def test_fetch_request_token_via_query(): - request_token = {'oauth_token': '1', 'oauth_token_secret': '2'} - - async def assert_func(request): - auth_header = request.headers.get('authorization') - assert auth_header is None - - url = str(request.url) - assert 'oauth_consumer_key=id' in url - assert '&oauth_signature=' in url - - mock_response = AsyncMockDispatch(request_token, assert_func=assert_func) - - async with AsyncOAuth1Client( - 'id', 'secret', signature_type=SIGNATURE_TYPE_QUERY, - app=mock_response, - ) as client: - response = await client.fetch_request_token(oauth_url) - - assert response == request_token - - -@pytest.mark.asyncio -async def test_fetch_access_token(): - request_token = {'oauth_token': '1', 'oauth_token_secret': '2'} - - async def assert_func(request): - auth_header = request.headers.get('authorization') - assert 'oauth_verifier="d"' in auth_header - assert 'oauth_token="foo"' in auth_header - assert 'oauth_consumer_key="id"' in auth_header - assert 'oauth_signature=' in auth_header - - mock_response = AsyncMockDispatch(request_token, assert_func=assert_func) - async with AsyncOAuth1Client( - 'id', 'secret', token='foo', token_secret='bar', - app=mock_response, - ) as client: - with pytest.raises(OAuthError): - await client.fetch_access_token(oauth_url) - - response = await client.fetch_access_token(oauth_url, verifier='d') - - assert response == request_token - - -@pytest.mark.asyncio -async def test_get_via_header(): - mock_response = AsyncMockDispatch(b'hello') - async with AsyncOAuth1Client( - 'id', 'secret', token='foo', token_secret='bar', - app=mock_response, - ) as client: - response = await client.get('https://example.com/') - - assert response.content == b'hello' - request = response.request - auth_header = request.headers.get('authorization') - assert 'oauth_token="foo"' in auth_header - assert 'oauth_consumer_key="id"' in auth_header - assert 'oauth_signature=' in auth_header - - -@pytest.mark.asyncio -async def test_get_via_body(): - async def assert_func(request): - content = await request.body() - assert b'oauth_token=foo' in content - assert b'oauth_consumer_key=id' in content - assert b'oauth_signature=' in content - - mock_response = AsyncMockDispatch(b'hello', assert_func=assert_func) - async with AsyncOAuth1Client( - 'id', 'secret', token='foo', token_secret='bar', - signature_type=SIGNATURE_TYPE_BODY, - app=mock_response, - ) as client: - response = await client.post('https://example.com/') - - assert response.content == b'hello' - - request = response.request - auth_header = request.headers.get('authorization') - assert auth_header is None - - -@pytest.mark.asyncio -async def test_get_via_query(): - mock_response = AsyncMockDispatch(b'hello') - async with AsyncOAuth1Client( - 'id', 'secret', token='foo', token_secret='bar', - signature_type=SIGNATURE_TYPE_QUERY, - app=mock_response, - ) as client: - response = await client.get('https://example.com/') - - assert response.content == b'hello' - request = response.request - auth_header = request.headers.get('authorization') - assert auth_header is None - - url = str(request.url) - assert 'oauth_token=foo' in url - assert 'oauth_consumer_key=id' in url - assert 'oauth_signature=' in url diff --git a/tests/py3/test_httpx_client/test_async_oauth2_client.py b/tests/py3/test_httpx_client/test_async_oauth2_client.py deleted file mode 100644 index 2333d2e58..000000000 --- a/tests/py3/test_httpx_client/test_async_oauth2_client.py +++ /dev/null @@ -1,416 +0,0 @@ -import asyncio -import mock -import time -import pytest -from copy import deepcopy -from authlib.common.security import generate_token -from authlib.common.urls import url_encode -from authlib.integrations.httpx_client import ( - OAuthError, - AsyncOAuth2Client, -) -from tests.py3.utils import AsyncMockDispatch - - -default_token = { - 'token_type': 'Bearer', - 'access_token': 'a', - 'refresh_token': 'b', - 'expires_in': '3600', - 'expires_at': int(time.time()) + 3600, -} - - -@pytest.mark.asyncio -async def test_add_token_to_header(): - async def assert_func(request): - token = 'Bearer ' + default_token['access_token'] - auth_header = request.headers.get('authorization') - assert auth_header == token - - mock_response = AsyncMockDispatch({'a': 'a'}, assert_func=assert_func) - async with AsyncOAuth2Client( - 'foo', - token=default_token, - app=mock_response - ) as client: - resp = await client.get('https://i.b') - - data = resp.json() - assert data['a'] == 'a' - - -@pytest.mark.asyncio -async def test_add_token_to_body(): - async def assert_func(request): - content = await request.body() - assert default_token['access_token'] in content.decode() - - mock_response = AsyncMockDispatch({'a': 'a'}, assert_func=assert_func) - async with AsyncOAuth2Client( - 'foo', - token=default_token, - token_placement='body', - app=mock_response - ) as client: - resp = await client.get('https://i.b') - - data = resp.json() - assert data['a'] == 'a' - - -@pytest.mark.asyncio -async def test_add_token_to_uri(): - async def assert_func(request): - assert default_token['access_token'] in str(request.url) - - mock_response = AsyncMockDispatch({'a': 'a'}, assert_func=assert_func) - async with AsyncOAuth2Client( - 'foo', - token=default_token, - token_placement='uri', - app=mock_response - ) as client: - resp = await client.get('https://i.b') - - data = resp.json() - assert data['a'] == 'a' - - -def test_create_authorization_url(): - url = 'https://example.com/authorize?foo=bar' - - sess = AsyncOAuth2Client(client_id='foo') - auth_url, state = sess.create_authorization_url(url) - assert state in auth_url - assert 'client_id=foo' in auth_url - assert 'response_type=code' in auth_url - - sess = AsyncOAuth2Client(client_id='foo', prompt='none') - auth_url, state = sess.create_authorization_url( - url, state='foo', redirect_uri='https://i.b', scope='profile') - assert state == 'foo' - assert 'i.b' in auth_url - assert 'profile' in auth_url - assert 'prompt=none' in auth_url - - -def test_code_challenge(): - sess = AsyncOAuth2Client('foo', code_challenge_method='S256') - - url = 'https://example.com/authorize' - auth_url, _ = sess.create_authorization_url( - url, code_verifier=generate_token(48)) - assert 'code_challenge=' in auth_url - assert 'code_challenge_method=S256' in auth_url - - -def test_token_from_fragment(): - sess = AsyncOAuth2Client('foo') - response_url = 'https://i.b/callback#' + url_encode(default_token.items()) - assert sess.token_from_fragment(response_url) == default_token - token = sess.fetch_token(authorization_response=response_url) - assert token == default_token - - -@pytest.mark.asyncio -async def test_fetch_token_post(): - url = 'https://example.com/token' - - async def assert_func(request): - content = await request.body() - content = content.decode() - assert 'code=v' in content - assert 'client_id=' in content - assert 'grant_type=authorization_code' in content - - mock_response = AsyncMockDispatch(default_token, assert_func=assert_func) - async with AsyncOAuth2Client('foo', app=mock_response) as client: - token = await client.fetch_token(url, authorization_response='https://i.b/?code=v') - assert token == default_token - - async with AsyncOAuth2Client( - 'foo', - token_endpoint_auth_method='none', - app=mock_response - ) as client: - token = await client.fetch_token(url, code='v') - assert token == default_token - - mock_response = AsyncMockDispatch({'error': 'invalid_request'}) - async with AsyncOAuth2Client('foo', app=mock_response) as client: - with pytest.raises(OAuthError): - await client.fetch_token(url) - - -@pytest.mark.asyncio -async def test_fetch_token_get(): - url = 'https://example.com/token' - - async def assert_func(request): - url = str(request.url) - assert 'code=v' in url - assert 'client_id=' in url - assert 'grant_type=authorization_code' in url - - mock_response = AsyncMockDispatch(default_token, assert_func=assert_func) - async with AsyncOAuth2Client('foo', app=mock_response) as client: - authorization_response = 'https://i.b/?code=v' - token = await client.fetch_token( - url, authorization_response=authorization_response, method='GET') - assert token == default_token - - async with AsyncOAuth2Client( - 'foo', - token_endpoint_auth_method='none', - app=mock_response - ) as client: - token = await client.fetch_token(url, code='v', method='GET') - assert token == default_token - - token = await client.fetch_token(url + '?q=a', code='v', method='GET') - assert token == default_token - - -@pytest.mark.asyncio -async def test_token_auth_method_client_secret_post(): - url = 'https://example.com/token' - - async def assert_func(request): - content = await request.body() - content = content.decode() - assert 'code=v' in content - assert 'client_id=' in content - assert 'client_secret=bar' in content - assert 'grant_type=authorization_code' in content - - mock_response = AsyncMockDispatch(default_token, assert_func=assert_func) - async with AsyncOAuth2Client( - 'foo', 'bar', - token_endpoint_auth_method='client_secret_post', - app=mock_response - ) as client: - token = await client.fetch_token(url, code='v') - - assert token == default_token - - -@pytest.mark.asyncio -async def test_access_token_response_hook(): - url = 'https://example.com/token' - - def _access_token_response_hook(resp): - assert resp.json() == default_token - return resp - - access_token_response_hook = mock.Mock(side_effect=_access_token_response_hook) - app = AsyncMockDispatch(default_token) - async with AsyncOAuth2Client('foo', token=default_token, app=app) as sess: - sess.register_compliance_hook( - 'access_token_response', - access_token_response_hook - ) - assert await sess.fetch_token(url) == default_token - assert access_token_response_hook.called is True - - -@pytest.mark.asyncio -async def test_password_grant_type(): - url = 'https://example.com/token' - - async def assert_func(request): - content = await request.body() - content = content.decode() - assert 'username=v' in content - assert 'scope=profile' in content - assert 'grant_type=password' in content - - app = AsyncMockDispatch(default_token, assert_func=assert_func) - async with AsyncOAuth2Client('foo', scope='profile', app=app) as sess: - token = await sess.fetch_token(url, username='v', password='v') - assert token == default_token - - token = await sess.fetch_token( - url, username='v', password='v', grant_type='password') - assert token == default_token - - -@pytest.mark.asyncio -async def test_client_credentials_type(): - url = 'https://example.com/token' - - async def assert_func(request): - content = await request.body() - content = content.decode() - assert 'scope=profile' in content - assert 'grant_type=client_credentials' in content - - app = AsyncMockDispatch(default_token, assert_func=assert_func) - async with AsyncOAuth2Client('foo', scope='profile', app=app) as sess: - token = await sess.fetch_token(url) - assert token == default_token - - token = await sess.fetch_token(url, grant_type='client_credentials') - assert token == default_token - - -@pytest.mark.asyncio -async def test_cleans_previous_token_before_fetching_new_one(): - now = int(time.time()) - new_token = deepcopy(default_token) - past = now - 7200 - default_token['expires_at'] = past - new_token['expires_at'] = now + 3600 - url = 'https://example.com/token' - - app = AsyncMockDispatch(new_token) - with mock.patch('time.time', lambda: now): - async with AsyncOAuth2Client('foo', token=default_token, app=app) as sess: - assert await sess.fetch_token(url) == new_token - - -def test_token_status(): - token = dict(access_token='a', token_type='bearer', expires_at=100) - sess = AsyncOAuth2Client('foo', token=token) - assert sess.token.is_expired() is True - - -@pytest.mark.asyncio -async def test_auto_refresh_token(): - - async def _update_token(token, refresh_token=None, access_token=None): - assert refresh_token == 'b' - assert token == default_token - - update_token = mock.Mock(side_effect=_update_token) - - old_token = dict( - access_token='a', refresh_token='b', - token_type='bearer', expires_at=100 - ) - - app = AsyncMockDispatch(default_token) - async with AsyncOAuth2Client( - 'foo', token=old_token, token_endpoint='https://i.b/token', - update_token=update_token, app=app - ) as sess: - await sess.get('https://i.b/user') - assert update_token.called is True - - old_token = dict( - access_token='a', - token_type='bearer', - expires_at=100 - ) - async with AsyncOAuth2Client( - 'foo', token=old_token, token_endpoint='https://i.b/token', - update_token=update_token, app=app - ) as sess: - with pytest.raises(OAuthError): - await sess.get('https://i.b/user') - - -@pytest.mark.asyncio -async def test_auto_refresh_token2(): - - async def _update_token(token, refresh_token=None, access_token=None): - assert access_token == 'a' - assert token == default_token - - update_token = mock.Mock(side_effect=_update_token) - - old_token = dict( - access_token='a', - token_type='bearer', - expires_at=100 - ) - - app = AsyncMockDispatch(default_token) - - async with AsyncOAuth2Client( - 'foo', token=old_token, - token_endpoint='https://i.b/token', - grant_type='client_credentials', - app=app, - ) as client: - await client.get('https://i.b/user') - assert update_token.called is False - - async with AsyncOAuth2Client( - 'foo', token=old_token, token_endpoint='https://i.b/token', - update_token=update_token, grant_type='client_credentials', - app=app, - ) as client: - await client.get('https://i.b/user') - assert update_token.called is True - - -@pytest.mark.asyncio -async def test_auto_refresh_token3(): - async def _update_token(token, refresh_token=None, access_token=None): - assert access_token == 'a' - assert token == default_token - - update_token = mock.Mock(side_effect=_update_token) - - old_token = dict( - access_token='a', - token_type='bearer', - expires_at=100 - ) - - app = AsyncMockDispatch(default_token) - - async with AsyncOAuth2Client( - 'foo', token=old_token, token_endpoint='https://i.b/token', - update_token=update_token, grant_type='client_credentials', - app=app, - ) as client: - await client.post('https://i.b/user', json={'foo': 'bar'}) - assert update_token.called is True - -@pytest.mark.asyncio -async def test_auto_refresh_token4(): - async def _update_token(token, refresh_token=None, access_token=None): - await asyncio.sleep(0.1) # artificial sleep to force other coroutines to wake - - update_token = mock.Mock(side_effect=_update_token) - - old_token = dict( - access_token='a', - token_type='bearer', - expires_at=100 - ) - - app = AsyncMockDispatch(default_token) - - async with AsyncOAuth2Client( - 'foo', token=old_token, token_endpoint='https://i.b/token', - update_token=update_token, grant_type='client_credentials', - app=app, - ) as client: - coroutines = [client.get('https://i.b/user') for x in range(10)] - await asyncio.gather(*coroutines) - update_token.assert_called_once() - -@pytest.mark.asyncio -async def test_revoke_token(): - answer = {'status': 'ok'} - app = AsyncMockDispatch(answer) - - async with AsyncOAuth2Client('a', app=app) as sess: - resp = await sess.revoke_token('https://i.b/token', 'hi') - assert resp.json() == answer - - resp = await sess.revoke_token( - 'https://i.b/token', 'hi', - token_type_hint='access_token' - ) - assert resp.json() == answer - - -@pytest.mark.asyncio -async def test_request_without_token(): - async with AsyncOAuth2Client('a', app=AsyncMockDispatch()) as client: - with pytest.raises(OAuthError): - await client.get('https://i.b/token') diff --git a/tests/py3/test_httpx_client/test_oauth1_client.py b/tests/py3/test_httpx_client/test_oauth1_client.py deleted file mode 100644 index a5f34df3e..000000000 --- a/tests/py3/test_httpx_client/test_oauth1_client.py +++ /dev/null @@ -1,157 +0,0 @@ -import pytest -from authlib.integrations.httpx_client import ( - OAuthError, - OAuth1Client, - SIGNATURE_TYPE_BODY, - SIGNATURE_TYPE_QUERY, -) -from tests.py3.utils import MockDispatch - -oauth_url = 'https://example.com/oauth' - - -@pytest.mark.asyncio -def test_fetch_request_token_via_header(): - request_token = {'oauth_token': '1', 'oauth_token_secret': '2'} - - def assert_func(request): - auth_header = request.headers.get('authorization') - assert 'oauth_consumer_key="id"' in auth_header - assert 'oauth_signature=' in auth_header - - app = MockDispatch(request_token, assert_func=assert_func) - with OAuth1Client('id', 'secret', app=app) as client: - response = client.fetch_request_token(oauth_url) - - assert response == request_token - - -@pytest.mark.asyncio -def test_fetch_request_token_via_body(): - request_token = {'oauth_token': '1', 'oauth_token_secret': '2'} - - def assert_func(request): - auth_header = request.headers.get('authorization') - assert auth_header is None - - content = request.form - assert content.get('oauth_consumer_key') == 'id' - assert 'oauth_signature' in content - - mock_response = MockDispatch(request_token, assert_func=assert_func) - - with OAuth1Client( - 'id', 'secret', signature_type=SIGNATURE_TYPE_BODY, - app=mock_response, - ) as client: - response = client.fetch_request_token(oauth_url) - - assert response == request_token - - -@pytest.mark.asyncio -def test_fetch_request_token_via_query(): - request_token = {'oauth_token': '1', 'oauth_token_secret': '2'} - - def assert_func(request): - auth_header = request.headers.get('authorization') - assert auth_header is None - - url = str(request.url) - assert 'oauth_consumer_key=id' in url - assert '&oauth_signature=' in url - - mock_response = MockDispatch(request_token, assert_func=assert_func) - - with OAuth1Client( - 'id', 'secret', signature_type=SIGNATURE_TYPE_QUERY, - app=mock_response, - ) as client: - response = client.fetch_request_token(oauth_url) - - assert response == request_token - - -@pytest.mark.asyncio -def test_fetch_access_token(): - request_token = {'oauth_token': '1', 'oauth_token_secret': '2'} - - def assert_func(request): - auth_header = request.headers.get('authorization') - assert 'oauth_verifier="d"' in auth_header - assert 'oauth_token="foo"' in auth_header - assert 'oauth_consumer_key="id"' in auth_header - assert 'oauth_signature=' in auth_header - - mock_response = MockDispatch(request_token, assert_func=assert_func) - with OAuth1Client( - 'id', 'secret', token='foo', token_secret='bar', - app=mock_response, - ) as client: - with pytest.raises(OAuthError): - client.fetch_access_token(oauth_url) - - response = client.fetch_access_token(oauth_url, verifier='d') - - assert response == request_token - - -@pytest.mark.asyncio -def test_get_via_header(): - mock_response = MockDispatch(b'hello') - with OAuth1Client( - 'id', 'secret', token='foo', token_secret='bar', - app=mock_response, - ) as client: - response = client.get('https://example.com/') - - assert response.content == b'hello' - request = response.request - auth_header = request.headers.get('authorization') - assert 'oauth_token="foo"' in auth_header - assert 'oauth_consumer_key="id"' in auth_header - assert 'oauth_signature=' in auth_header - - -@pytest.mark.asyncio -def test_get_via_body(): - def assert_func(request): - content = request.form - assert content.get('oauth_token') == 'foo' - assert content.get('oauth_consumer_key') == 'id' - assert 'oauth_signature' in content - - mock_response = MockDispatch(b'hello', assert_func=assert_func) - with OAuth1Client( - 'id', 'secret', token='foo', token_secret='bar', - signature_type=SIGNATURE_TYPE_BODY, - app=mock_response, - ) as client: - response = client.post('https://example.com/') - - assert response.content == b'hello' - - request = response.request - auth_header = request.headers.get('authorization') - assert auth_header is None - - -@pytest.mark.asyncio -def test_get_via_query(): - mock_response = MockDispatch(b'hello') - with OAuth1Client( - 'id', 'secret', token='foo', token_secret='bar', - signature_type=SIGNATURE_TYPE_QUERY, - app=mock_response, - ) as client: - response = client.get('https://example.com/') - - assert response.content == b'hello' - request = response.request - auth_header = request.headers.get('authorization') - assert auth_header is None - - url = str(request.url) - assert 'oauth_token=foo' in url - assert 'oauth_consumer_key=id' in url - assert 'oauth_signature=' in url diff --git a/tests/py3/test_httpx_client/test_oauth2_client.py b/tests/py3/test_httpx_client/test_oauth2_client.py deleted file mode 100644 index 7bd39387e..000000000 --- a/tests/py3/test_httpx_client/test_oauth2_client.py +++ /dev/null @@ -1,374 +0,0 @@ -import mock -import time -import pytest -from copy import deepcopy -from authlib.common.security import generate_token -from authlib.common.urls import url_encode -from authlib.integrations.httpx_client import ( - OAuthError, - OAuth2Client, -) -from tests.py3.utils import MockDispatch - - -default_token = { - 'token_type': 'Bearer', - 'access_token': 'a', - 'refresh_token': 'b', - 'expires_in': '3600', - 'expires_at': int(time.time()) + 3600, -} - - -def test_add_token_to_header(): - def assert_func(request): - token = 'Bearer ' + default_token['access_token'] - auth_header = request.headers.get('authorization') - assert auth_header == token - - mock_response = MockDispatch({'a': 'a'}, assert_func=assert_func) - with OAuth2Client( - 'foo', - token=default_token, - app=mock_response - ) as client: - resp = client.get('https://i.b') - - data = resp.json() - assert data['a'] == 'a' - - -def test_add_token_to_body(): - def assert_func(request): - content = request.data - content = content.decode() - assert content == 'access_token=%s' % default_token['access_token'] - - mock_response = MockDispatch({'a': 'a'}, assert_func=assert_func) - with OAuth2Client( - 'foo', - token=default_token, - token_placement='body', - app=mock_response - ) as client: - resp = client.get('https://i.b') - - data = resp.json() - assert data['a'] == 'a' - - -def test_add_token_to_uri(): - def assert_func(request): - assert default_token['access_token'] in str(request.url) - - mock_response = MockDispatch({'a': 'a'}, assert_func=assert_func) - with OAuth2Client( - 'foo', - token=default_token, - token_placement='uri', - app=mock_response - ) as client: - resp = client.get('https://i.b') - - data = resp.json() - assert data['a'] == 'a' - - -def test_create_authorization_url(): - url = 'https://example.com/authorize?foo=bar' - - sess = OAuth2Client(client_id='foo') - auth_url, state = sess.create_authorization_url(url) - assert state in auth_url - assert 'client_id=foo' in auth_url - assert 'response_type=code' in auth_url - - sess = OAuth2Client(client_id='foo', prompt='none') - auth_url, state = sess.create_authorization_url( - url, state='foo', redirect_uri='https://i.b', scope='profile') - assert state == 'foo' - assert 'i.b' in auth_url - assert 'profile' in auth_url - assert 'prompt=none' in auth_url - - -def test_code_challenge(): - sess = OAuth2Client('foo', code_challenge_method='S256') - - url = 'https://example.com/authorize' - auth_url, _ = sess.create_authorization_url( - url, code_verifier=generate_token(48)) - assert 'code_challenge=' in auth_url - assert 'code_challenge_method=S256' in auth_url - - -def test_token_from_fragment(): - sess = OAuth2Client('foo') - response_url = 'https://i.b/callback#' + url_encode(default_token.items()) - assert sess.token_from_fragment(response_url) == default_token - token = sess.fetch_token(authorization_response=response_url) - assert token == default_token - - -def test_fetch_token_post(): - url = 'https://example.com/token' - - def assert_func(request): - content = request.form - assert content.get('code') == 'v' - assert content.get('client_id') == 'foo' - assert content.get('grant_type') == 'authorization_code' - - mock_response = MockDispatch(default_token, assert_func=assert_func) - with OAuth2Client('foo', app=mock_response) as client: - token = client.fetch_token(url, authorization_response='https://i.b/?code=v') - assert token == default_token - - with OAuth2Client( - 'foo', - token_endpoint_auth_method='none', - app=mock_response - ) as client: - token = client.fetch_token(url, code='v') - assert token == default_token - - mock_response = MockDispatch({'error': 'invalid_request'}) - with OAuth2Client('foo', app=mock_response) as client: - with pytest.raises(OAuthError): - client.fetch_token(url) - - -def test_fetch_token_get(): - url = 'https://example.com/token' - - def assert_func(request): - url = str(request.url) - assert 'code=v' in url - assert 'client_id=' in url - assert 'grant_type=authorization_code' in url - - mock_response = MockDispatch(default_token, assert_func=assert_func) - with OAuth2Client('foo', app=mock_response) as client: - authorization_response = 'https://i.b/?code=v' - token = client.fetch_token( - url, authorization_response=authorization_response, method='GET') - assert token == default_token - - with OAuth2Client( - 'foo', - token_endpoint_auth_method='none', - app=mock_response - ) as client: - token = client.fetch_token(url, code='v', method='GET') - assert token == default_token - - token = client.fetch_token(url + '?q=a', code='v', method='GET') - assert token == default_token - - -def test_token_auth_method_client_secret_post(): - url = 'https://example.com/token' - - def assert_func(request): - content = request.form - assert content.get('code') == 'v' - assert content.get('client_id') == 'foo' - assert content.get('client_secret') == 'bar' - assert content.get('grant_type') == 'authorization_code' - - mock_response = MockDispatch(default_token, assert_func=assert_func) - with OAuth2Client( - 'foo', 'bar', - token_endpoint_auth_method='client_secret_post', - app=mock_response - ) as client: - token = client.fetch_token(url, code='v') - - assert token == default_token - - -def test_access_token_response_hook(): - url = 'https://example.com/token' - - def _access_token_response_hook(resp): - assert resp.json() == default_token - return resp - - access_token_response_hook = mock.Mock(side_effect=_access_token_response_hook) - app = MockDispatch(default_token) - with OAuth2Client('foo', token=default_token, app=app) as sess: - sess.register_compliance_hook( - 'access_token_response', - access_token_response_hook - ) - assert sess.fetch_token(url) == default_token - assert access_token_response_hook.called is True - - -def test_password_grant_type(): - url = 'https://example.com/token' - - def assert_func(request): - content = request.form - assert content.get('username') == 'v' - assert content.get('scope') == 'profile' - assert content.get('grant_type') == 'password' - - app = MockDispatch(default_token, assert_func=assert_func) - with OAuth2Client('foo', scope='profile', app=app) as sess: - token = sess.fetch_token(url, username='v', password='v') - assert token == default_token - - token = sess.fetch_token( - url, username='v', password='v', grant_type='password') - assert token == default_token - - -def test_client_credentials_type(): - url = 'https://example.com/token' - - def assert_func(request): - content = request.form - assert content.get('scope') == 'profile' - assert content.get('grant_type') == 'client_credentials' - - app = MockDispatch(default_token, assert_func=assert_func) - with OAuth2Client('foo', scope='profile', app=app) as sess: - token = sess.fetch_token(url) - assert token == default_token - - token = sess.fetch_token(url, grant_type='client_credentials') - assert token == default_token - - -def test_cleans_previous_token_before_fetching_new_one(): - now = int(time.time()) - new_token = deepcopy(default_token) - past = now - 7200 - default_token['expires_at'] = past - new_token['expires_at'] = now + 3600 - url = 'https://example.com/token' - - app = MockDispatch(new_token) - with mock.patch('time.time', lambda: now): - with OAuth2Client('foo', token=default_token, app=app) as sess: - assert sess.fetch_token(url) == new_token - - -def test_token_status(): - token = dict(access_token='a', token_type='bearer', expires_at=100) - sess = OAuth2Client('foo', token=token) - assert sess.token.is_expired() is True - - -def test_auto_refresh_token(): - - def _update_token(token, refresh_token=None, access_token=None): - assert refresh_token == 'b' - assert token == default_token - - update_token = mock.Mock(side_effect=_update_token) - - old_token = dict( - access_token='a', refresh_token='b', - token_type='bearer', expires_at=100 - ) - - app = MockDispatch(default_token) - with OAuth2Client( - 'foo', token=old_token, token_endpoint='https://i.b/token', - update_token=update_token, app=app - ) as sess: - sess.get('https://i.b/user') - assert update_token.called is True - - old_token = dict( - access_token='a', - token_type='bearer', - expires_at=100 - ) - with OAuth2Client( - 'foo', token=old_token, token_endpoint='https://i.b/token', - update_token=update_token, app=app - ) as sess: - with pytest.raises(OAuthError): - sess.get('https://i.b/user') - - -def test_auto_refresh_token2(): - - def _update_token(token, refresh_token=None, access_token=None): - assert access_token == 'a' - assert token == default_token - - update_token = mock.Mock(side_effect=_update_token) - - old_token = dict( - access_token='a', - token_type='bearer', - expires_at=100 - ) - - app = MockDispatch(default_token) - - with OAuth2Client( - 'foo', token=old_token, - token_endpoint='https://i.b/token', - grant_type='client_credentials', - app=app, - ) as client: - client.get('https://i.b/user') - assert update_token.called is False - - with OAuth2Client( - 'foo', token=old_token, token_endpoint='https://i.b/token', - update_token=update_token, grant_type='client_credentials', - app=app, - ) as client: - client.get('https://i.b/user') - assert update_token.called is True - - -def test_auto_refresh_token3(): - def _update_token(token, refresh_token=None, access_token=None): - assert access_token == 'a' - assert token == default_token - - update_token = mock.Mock(side_effect=_update_token) - - old_token = dict( - access_token='a', - token_type='bearer', - expires_at=100 - ) - - app = MockDispatch(default_token) - - with OAuth2Client( - 'foo', token=old_token, token_endpoint='https://i.b/token', - update_token=update_token, grant_type='client_credentials', - app=app, - ) as client: - client.post('https://i.b/user', json={'foo': 'bar'}) - assert update_token.called is True - - -def test_revoke_token(): - answer = {'status': 'ok'} - app = MockDispatch(answer) - - with OAuth2Client('a', app=app) as sess: - resp = sess.revoke_token('https://i.b/token', 'hi') - assert resp.json() == answer - - resp = sess.revoke_token( - 'https://i.b/token', 'hi', - token_type_hint='access_token' - ) - assert resp.json() == answer - - -def test_request_without_token(): - with OAuth2Client('a', app=MockDispatch()) as client: - with pytest.raises(OAuthError): - client.get('https://i.b/token') diff --git a/tests/py3/test_starlette_client/test_oauth_client.py b/tests/py3/test_starlette_client/test_oauth_client.py deleted file mode 100644 index 68654fc82..000000000 --- a/tests/py3/test_starlette_client/test_oauth_client.py +++ /dev/null @@ -1,272 +0,0 @@ -import pytest -from starlette.config import Config -from starlette.requests import Request -from authlib.integrations.starlette_client import OAuth -from tests.py3.utils import AsyncPathMapDispatch -from tests.client_base import get_bearer_token - - -def test_register_remote_app(): - oauth = OAuth() - with pytest.raises(AttributeError): - assert oauth.dev.name == 'dev' - - oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - ) - assert oauth.dev.name == 'dev' - assert oauth.dev.client_id == 'dev' - - -def test_register_with_config(): - config = Config(environ={'DEV_CLIENT_ID': 'dev'}) - oauth = OAuth(config) - oauth.register('dev') - assert oauth.dev.name == 'dev' - assert oauth.dev.client_id == 'dev' - - -def test_register_with_overwrite(): - config = Config(environ={'DEV_CLIENT_ID': 'dev'}) - oauth = OAuth(config) - oauth.register('dev', client_id='not-dev', overwrite=True) - assert oauth.dev.name == 'dev' - assert oauth.dev.client_id == 'dev' - - -@pytest.mark.asyncio -async def test_oauth1_authorize(): - oauth = OAuth() - app = AsyncPathMapDispatch({ - '/request-token': {'body': 'oauth_token=foo&oauth_verifier=baz'}, - '/token': {'body': 'oauth_token=a&oauth_token_secret=b'}, - }) - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - request_token_url='https://i.b/request-token', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize', - client_kwargs={ - 'app': app, - } - ) - - req_scope = {'type': 'http', 'session': {}} - req = Request(req_scope) - resp = await client.authorize_redirect(req, 'https://b.com/bar') - assert resp.status_code == 302 - url = resp.headers.get('Location') - assert 'oauth_token=foo' in url - - req_token = req.session.get('_dev_authlib_request_token_') - assert req_token is not None - - req.scope['query_string'] = 'oauth_token=foo&oauth_verifier=baz' - token = await client.authorize_access_token(req) - assert token['oauth_token'] == 'a' - - -@pytest.mark.asyncio -async def test_oauth2_authorize(): - oauth = OAuth() - app = AsyncPathMapDispatch({ - '/token': {'body': get_bearer_token()} - }) - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize', - client_kwargs={ - 'app': app, - } - ) - - req_scope = {'type': 'http', 'session': {}} - req = Request(req_scope) - resp = await client.authorize_redirect(req, 'https://b.com/bar') - assert resp.status_code == 302 - url = resp.headers.get('Location') - assert 'state=' in url - - state = req.session.get('_dev_authlib_state_') - assert state is not None - - req_scope.update( - { - 'path': '/', - 'query_string': f'code=a&state={state}', - 'session': req.session, - } - ) - req = Request(req_scope) - token = await client.authorize_access_token(req) - assert token['access_token'] == 'a' - - -@pytest.mark.asyncio -async def test_oauth2_authorize_code_challenge(): - app = AsyncPathMapDispatch({ - '/token': {'body': get_bearer_token()} - }) - oauth = OAuth() - client = oauth.register( - 'dev', - client_id='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize', - client_kwargs={ - 'code_challenge_method': 'S256', - 'app': app, - }, - ) - - req_scope = {'type': 'http', 'session': {}} - req = Request(req_scope) - - resp = await client.authorize_redirect(req, redirect_uri='https://b.com/bar') - assert resp.status_code == 302 - - url = resp.headers.get('Location') - assert 'code_challenge=' in url - assert 'code_challenge_method=S256' in url - - state = req.session['_dev_authlib_state_'] - assert state is not None - - verifier = req.session['_dev_authlib_code_verifier_'] - assert verifier is not None - - req_scope.update( - { - 'path': '/', - 'query_string': 'code=a&state={}'.format(state).encode(), - 'session': req.session, - } - ) - req = Request(req_scope) - - token = await client.authorize_access_token(req) - assert token['access_token'] == 'a' - - -@pytest.mark.asyncio -async def test_with_fetch_token_in_register(): - async def fetch_token(request): - return {'access_token': 'dev', 'token_type': 'bearer'} - - app = AsyncPathMapDispatch({ - '/user': {'body': {'sub': '123'}} - }) - oauth = OAuth() - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize', - fetch_token=fetch_token, - client_kwargs={ - 'app': app, - } - ) - - req_scope = {'type': 'http', 'session': {}} - req = Request(req_scope) - resp = await client.get('/user', request=req) - assert resp.json()['sub'] == '123' - - -@pytest.mark.asyncio -async def test_with_fetch_token_in_oauth(): - async def fetch_token(name, request): - return {'access_token': 'dev', 'token_type': 'bearer'} - - app = AsyncPathMapDispatch({ - '/user': {'body': {'sub': '123'}} - }) - oauth = OAuth(fetch_token=fetch_token) - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - authorize_url='https://i.b/authorize', - client_kwargs={ - 'app': app, - } - ) - - req_scope = {'type': 'http', 'session': {}} - req = Request(req_scope) - resp = await client.get('/user', request=req) - assert resp.json()['sub'] == '123' - - -@pytest.mark.asyncio -async def test_request_withhold_token(): - oauth = OAuth() - app = AsyncPathMapDispatch({ - '/user': {'body': {'sub': '123'}} - }) - client = oauth.register( - "dev", - client_id="dev", - client_secret="dev", - api_base_url="https://i.b/api", - access_token_url="https://i.b/token", - authorize_url="https://i.b/authorize", - client_kwargs={ - 'app': app, - } - ) - req_scope = {'type': 'http', 'session': {}} - req = Request(req_scope) - resp = await client.get('/user', request=req, withhold_token=True) - assert resp.json()['sub'] == '123' - - -@pytest.mark.asyncio -async def test_oauth2_authorize_with_metadata(): - oauth = OAuth() - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - ) - req_scope = {'type': 'http', 'session': {}} - req = Request(req_scope) - with pytest.raises(RuntimeError): - await client.create_authorization_url(req) - - - app = AsyncPathMapDispatch({ - '/.well-known/openid-configuration': {'body': { - 'authorization_endpoint': 'https://i.b/authorize' - }} - }) - client = oauth.register( - 'dev2', - client_id='dev', - client_secret='dev', - api_base_url='https://i.b/api', - access_token_url='https://i.b/token', - server_metadata_url='https://i.b/.well-known/openid-configuration', - client_kwargs={ - 'app': app, - } - ) - resp = await client.authorize_redirect(req, 'https://b.com/bar') - assert resp.status_code == 302 diff --git a/tests/py3/test_starlette_client/test_user_mixin.py b/tests/py3/test_starlette_client/test_user_mixin.py deleted file mode 100644 index f9e32b569..000000000 --- a/tests/py3/test_starlette_client/test_user_mixin.py +++ /dev/null @@ -1,149 +0,0 @@ -import pytest -from starlette.requests import Request -from authlib.integrations.starlette_client import OAuth -from authlib.jose import jwk -from authlib.jose.errors import InvalidClaimError -from authlib.oidc.core.grants.util import generate_id_token -from tests.util import read_file_path -from tests.py3.utils import AsyncPathMapDispatch -from tests.client_base import get_bearer_token - - -async def run_fetch_userinfo(payload, compliance_fix=None): - oauth = OAuth() - - async def fetch_token(request): - return get_bearer_token() - - app = AsyncPathMapDispatch({ - '/userinfo': {'body': payload} - }) - - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - fetch_token=fetch_token, - userinfo_endpoint='https://i.b/userinfo', - userinfo_compliance_fix=compliance_fix, - client_kwargs={ - 'app': app, - } - ) - - req_scope = {'type': 'http', 'session': {}} - req = Request(req_scope) - user = await client.userinfo(request=req) - assert user.sub == '123' - - -@pytest.mark.asyncio -async def test_fetch_userinfo(): - await run_fetch_userinfo({'sub': '123'}) - - -@pytest.mark.asyncio -async def test_userinfo_compliance_fix(): - async def _fix(remote, data): - return {'sub': data['id']} - - await run_fetch_userinfo({'id': '123'}, _fix) - - -@pytest.mark.asyncio -async def test_parse_id_token(): - key = jwk.dumps('secret', 'oct', kid='f') - token = get_bearer_token() - id_token = generate_id_token( - token, {'sub': '123'}, key, - alg='HS256', iss='https://i.b', - aud='dev', exp=3600, nonce='n', - ) - - oauth = OAuth() - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - fetch_token=get_bearer_token, - jwks={'keys': [key]}, - issuer='https://i.b', - id_token_signing_alg_values_supported=['HS256', 'RS256'], - ) - req_scope = {'type': 'http', 'session': {'_dev_authlib_nonce_': 'n'}} - req = Request(req_scope) - - user = await client.parse_id_token(req, token) - assert user is None - - token['id_token'] = id_token - user = await client.parse_id_token(req, token) - assert user.sub == '123' - - claims_options = {'iss': {'value': 'https://i.b'}} - user = await client.parse_id_token(req, token, claims_options) - assert user.sub == '123' - - with pytest.raises(InvalidClaimError): - claims_options = {'iss': {'value': 'https://i.c'}} - await client.parse_id_token(req, token, claims_options) - - -@pytest.mark.asyncio -async def test_runtime_error_fetch_jwks_uri(): - key = jwk.dumps('secret', 'oct', kid='f') - token = get_bearer_token() - id_token = generate_id_token( - token, {'sub': '123'}, key, - alg='HS256', iss='https://i.b', - aud='dev', exp=3600, nonce='n', - ) - - oauth = OAuth() - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - fetch_token=get_bearer_token, - issuer='https://i.b', - id_token_signing_alg_values_supported=['HS256'], - ) - req_scope = {'type': 'http', 'session': {'_dev_authlib_nonce_': 'n'}} - req = Request(req_scope) - token['id_token'] = id_token - with pytest.raises(RuntimeError): - await client.parse_id_token(req, token) - - -@pytest.mark.asyncio -async def test_force_fetch_jwks_uri(): - secret_keys = read_file_path('jwks_private.json') - token = get_bearer_token() - id_token = generate_id_token( - token, {'sub': '123'}, secret_keys, - alg='RS256', iss='https://i.b', - aud='dev', exp=3600, nonce='n', - ) - - app = AsyncPathMapDispatch({ - '/jwks': {'body': read_file_path('jwks_public.json')} - }) - - oauth = OAuth() - client = oauth.register( - 'dev', - client_id='dev', - client_secret='dev', - fetch_token=get_bearer_token, - jwks_uri='https://i.b/jwks', - issuer='https://i.b', - client_kwargs={ - 'app': app, - } - ) - - req_scope = {'type': 'http', 'session': {'_dev_authlib_nonce_': 'n'}} - req = Request(req_scope) - token['id_token'] = id_token - user = await client.parse_id_token(req, token) - assert user.sub == '123' diff --git a/tests/py3/utils.py b/tests/py3/utils.py deleted file mode 100644 index 9416e100d..000000000 --- a/tests/py3/utils.py +++ /dev/null @@ -1,121 +0,0 @@ -import json -from starlette.requests import Request as ASGIRequest -from starlette.responses import Response as ASGIResponse -from werkzeug.wrappers import Request as WSGIRequest -from werkzeug.wrappers import Response as WSGIResponse - - -class AsyncMockDispatch: - def __init__(self, body=b'', status_code=200, headers=None, - assert_func=None): - if headers is None: - headers = {} - if isinstance(body, dict): - body = json.dumps(body).encode() - headers['Content-Type'] = 'application/json' - else: - if isinstance(body, str): - body = body.encode() - headers['Content-Type'] = 'application/x-www-form-urlencoded' - - self.body = body - self.status_code = status_code - self.headers = headers - self.assert_func = assert_func - - async def __call__(self, scope, receive, send): - request = ASGIRequest(scope, receive=receive) - - if self.assert_func: - await self.assert_func(request) - - response = ASGIResponse( - status_code=self.status_code, - content=self.body, - headers=self.headers, - ) - await response(scope, receive, send) - - -class AsyncPathMapDispatch: - def __init__(self, path_maps): - self.path_maps = path_maps - - async def __call__(self, scope, receive, send): - request = ASGIRequest(scope, receive=receive) - - rv = self.path_maps[request.url.path] - status_code = rv.get('status_code', 200) - body = rv.get('body') - headers = rv.get('headers', {}) - if isinstance(body, dict): - body = json.dumps(body).encode() - headers['Content-Type'] = 'application/json' - else: - if isinstance(body, str): - body = body.encode() - headers['Content-Type'] = 'application/x-www-form-urlencoded' - - response = ASGIResponse( - status_code=status_code, - content=body, - headers=headers, - ) - await response(scope, receive, send) - -class MockDispatch: - def __init__(self, body=b'', status_code=200, headers=None, - assert_func=None): - if headers is None: - headers = {} - if isinstance(body, dict): - body = json.dumps(body).encode() - headers['Content-Type'] = 'application/json' - else: - if isinstance(body, str): - body = body.encode() - headers['Content-Type'] = 'application/x-www-form-urlencoded' - - self.body = body - self.status_code = status_code - self.headers = headers - self.assert_func = assert_func - - def __call__(self, environ, start_response): - request = WSGIRequest(environ) - - if self.assert_func: - self.assert_func(request) - - response = WSGIResponse( - status=self.status_code, - response=self.body, - headers=self.headers, - ) - return response(environ, start_response) - - -class PathMapDispatch: - def __init__(self, path_maps): - self.path_maps = path_maps - - def __call__(self, environ, start_response): - request = WSGIRequest(environ) - - rv = self.path_maps[request.url.path] - status_code = rv.get('status_code', 200) - body = rv.get('body', b'') - headers = rv.get('headers', {}) - if isinstance(body, dict): - body = json.dumps(body).encode() - headers['Content-Type'] = 'application/json' - else: - if isinstance(body, str): - body = body.encode() - headers['Content-Type'] = 'application/x-www-form-urlencoded' - response = WSGIResponse( - status=status_code, - response=body, - headers=headers, - ) - return response(environ, start_response) diff --git a/tests/util.py b/tests/util.py index 4b7ff15f8..81a5e7844 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,5 +1,6 @@ -import os import json +import os + from authlib.common.encoding import to_unicode from authlib.common.urls import url_decode @@ -7,12 +8,12 @@ def get_file_path(name): - return os.path.join(ROOT, 'files', name) + return os.path.join(ROOT, "files", name) def read_file_path(name): - with open(get_file_path(name), 'r') as f: - if name.endswith('.json'): + with open(get_file_path(name)) as f: + if name.endswith(".json"): return json.load(f) return f.read() diff --git a/tox.ini b/tox.ini index a8c5a354e..ced504e99 100644 --- a/tox.ini +++ b/tox.ini @@ -1,34 +1,37 @@ [tox] +requires = + tox>=4.22 +isolated_build = True envlist = - py{27,36,37,38} - {py36,py37,py38} - {py27,py36,py37,py38}-flask - {py36,py37,py38}-django + py{310,311,312,313,314,py310} + py{310,311,312,313,314,py310}-{clients,flask,django,jose} + docs coverage [testenv] -deps = - -rrequirements-test.txt - py27: unittest2 - flask: Flask - flask: Flask-SQLAlchemy - py3: httpx==0.14.3 - py3: pytest-asyncio - py3: starlette - py3: itsdangerous - py3: werkzeug - django: Django - django: pytest-django +dependency_groups = + dev + jose: jose + clients: clients + flask: flask + django: django setenv = TESTPATH=tests/core - RCFILE=setup.cfg - py27: RCFILE=.py27conf - py3: TESTPATH=tests/py3 + jose: TESTPATH=tests/jose + clients: TESTPATH=tests/clients flask: TESTPATH=tests/flask django: TESTPATH=tests/django commands = - coverage run --rcfile={env:RCFILE} --source=authlib -p -m pytest {env:TESTPATH} + coverage run --source=authlib -p -m pytest {posargs: {env:TESTPATH}} + +[testenv:docs] +dependency_groups = + clients + docs + flask +commands = + sphinx-build --builder html --write-all --jobs auto --fail-on-warning docs build/_html [testenv:coverage] skip_install = true