Skip to content

Commit b5dc69e

Browse files
committed
Symplify framework integrations for django and starlette
1 parent e44b54d commit b5dc69e

File tree

16 files changed

+256
-412
lines changed

16 files changed

+256
-412
lines changed

authlib/integrations/base_client/async_app.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import time
12
import logging
23
from authlib.common.urls import urlparse
34
from .errors import (
@@ -73,7 +74,13 @@ async def _on_update_token(self, token, refresh_token=None, access_token=None):
7374
)
7475

7576
async def load_server_metadata(self):
76-
raise NotImplementedError()
77+
if self._server_metadata_url and '_loaded_at' not in self.server_metadata:
78+
async with self.client_cls(**self.client_kwargs) as client:
79+
resp = await client.request('GET', self._server_metadata_url, withhold_token=True)
80+
metadata = resp.json()
81+
metadata['_loaded_at'] = time.time()
82+
self.server_metadata.update(metadata)
83+
return self.server_metadata
7784

7885
async def request(self, method, url, token=None, **kwargs):
7986
metadata = await self.load_server_metadata()

authlib/integrations/base_client/async_openid.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,21 @@
66

77
class AsyncOpenIDMixin(object):
88
async def fetch_jwk_set(self, force=False):
9-
raise NotImplementedError()
9+
metadata = await self.load_server_metadata()
10+
jwk_set = metadata.get('jwks')
11+
if jwk_set and not force:
12+
return jwk_set
13+
14+
uri = metadata.get('jwks_uri')
15+
if not uri:
16+
raise RuntimeError('Missing "jwks_uri" in metadata')
17+
18+
async with self.client_cls(**self.client_kwargs) as client:
19+
resp = await client.request('GET', uri, withhold_token=True)
20+
jwk_set = resp.json()
21+
22+
self.server_metadata['jwks'] = jwk_set
23+
return jwk_set
1024

1125
async def userinfo(self, **kwargs):
1226
"""Fetch user info from ``userinfo_endpoint``."""

authlib/integrations/base_client/framework_integration.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,43 +18,43 @@ def _get_cache_data(self, key):
1818
except (TypeError, ValueError):
1919
return None
2020

21-
def _clear_session_state(self, request):
21+
def _clear_session_state(self, session):
2222
now = time.time()
23-
for key in dict(request.session):
23+
for key in dict(session):
2424
if '_authlib_' in key:
2525
# TODO: remove in future
26-
request.session.pop(key)
26+
session.pop(key)
2727
elif key.startswith('_state_'):
28-
value = request.session[key]
28+
value = session[key]
2929
exp = value.get('exp')
3030
if not exp or exp < now:
31-
request.session.pop(key)
31+
session.pop(key)
3232

33-
def get_state_data(self, request, state):
33+
def get_state_data(self, session, state):
3434
key = f'_state_{self.name}_{state}'
3535
if self.cache:
3636
value = self._get_cache_data(key)
3737
else:
38-
value = request.session.get(key)
38+
value = session.get(key)
3939
if value:
4040
return value.get('data')
4141
return None
4242

43-
def set_state_data(self, request, state, data):
43+
def set_state_data(self, session, state, data):
4444
key = f'_state_{self.name}_{state}'
4545
if self.cache:
4646
self.cache.set(key, {'data': data}, self.expires_in)
4747
else:
4848
now = time.time()
49-
request.session[key] = {'data': data, 'exp': now + self.expires_in}
49+
session[key] = {'data': data, 'exp': now + self.expires_in}
5050

51-
def clear_state_data(self, request, state):
51+
def clear_state_data(self, session, state):
5252
key = f'_state_{self.name}_{state}'
5353
if self.cache:
5454
self.cache.delete(key)
5555
else:
56-
request.session.pop(key, None)
57-
self._clear_session_state(request)
56+
session.pop(key, None)
57+
self._clear_session_state(session)
5858

5959
def update_token(self, token, refresh_token=None, access_token=None):
6060
raise NotImplementedError()

authlib/integrations/base_client/sync_app.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import time
12
import logging
23
from authlib.common.urls import urlparse
34
from authlib.consts import default_user_agent
45
from authlib.common.security import generate_token
56
from .errors import (
7+
MismatchingStateError,
68
MissingRequestTokenError,
79
MissingTokenError,
810
)
@@ -204,6 +206,19 @@ def _get_oauth_client(self, **metadata):
204206
session.headers['User-Agent'] = self._user_agent
205207
return session
206208

209+
def _format_state_params(self, state_data, params):
210+
if state_data is None:
211+
raise MismatchingStateError()
212+
213+
code_verifier = state_data.get('code_verifier')
214+
if code_verifier:
215+
params['code_verifier'] = code_verifier
216+
217+
redirect_uri = state_data.get('redirect_uri')
218+
if redirect_uri:
219+
params['redirect_uri'] = redirect_uri
220+
return params
221+
207222
@staticmethod
208223
def _create_oauth2_authorization_url(client, authorization_endpoint, **kwargs):
209224
rv = {}
@@ -251,7 +266,15 @@ def request(self, method, url, token=None, **kwargs):
251266
return _http_request(self, session, method, url, token, kwargs)
252267

253268
def load_server_metadata(self):
254-
raise NotImplementedError()
269+
if self._server_metadata_url and '_loaded_at' not in self.server_metadata:
270+
with self.client_cls() as session:
271+
resp = session.get(
272+
self._server_metadata_url, withhold_token=True, **self.client_kwargs)
273+
metadata = resp.json()
274+
275+
metadata['_loaded_at'] = time.time()
276+
self.server_metadata.update(metadata)
277+
return self.server_metadata
255278

256279
def create_authorization_url(self, redirect_uri=None, **kwargs):
257280
"""Generate the authorization url and state for HTTP redirect.

authlib/integrations/base_client/sync_openid.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,21 @@
44

55
class OpenIDMixin(object):
66
def fetch_jwk_set(self, force=False):
7-
raise NotImplementedError()
7+
metadata = self.load_server_metadata()
8+
jwk_set = metadata.get('jwks')
9+
if jwk_set and not force:
10+
return jwk_set
11+
12+
uri = metadata.get('jwks_uri')
13+
if not uri:
14+
raise RuntimeError('Missing "jwks_uri" in metadata')
15+
16+
with self.client_cls() as session:
17+
resp = session.get(uri, withhold_token=True, **self.client_kwargs)
18+
jwk_set = resp.json()
19+
20+
self.server_metadata['jwks'] = jwk_set
21+
return jwk_set
822

923
def userinfo(self, **kwargs):
1024
"""Fetch user info from ``userinfo_endpoint``."""

authlib/integrations/django_client/apps.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
from django.http import HttpResponseRedirect
2-
from ..base_client import OAuthError, MismatchingStateError
3-
from ..requests_client.apps import OAuth1App, OAuth2App
2+
from ..requests_client import OAuth1Session, OAuth2Session
3+
from ..base_client import (
4+
BaseApp, OAuthError,
5+
OAuth1Mixin, OAuth2Mixin, OpenIDMixin,
6+
)
47

58

69
class DjangoAppMixin(object):
710
def save_authorize_data(self, request, **kwargs):
811
state = kwargs.pop('state', None)
912
if state:
10-
self.framework.set_state_data(request, state, kwargs)
13+
self.framework.set_state_data(request.session, state, kwargs)
1114
else:
1215
raise RuntimeError('Missing state value')
1316

@@ -24,7 +27,9 @@ def authorize_redirect(self, request, redirect_uri=None, **kwargs):
2427
return HttpResponseRedirect(rv['url'])
2528

2629

27-
class DjangoOAuth1App(DjangoAppMixin, OAuth1App):
30+
class DjangoOAuth1App(DjangoAppMixin, OAuth1Mixin, BaseApp):
31+
client_cls = OAuth1Session
32+
2833
def authorize_access_token(self, request, **kwargs):
2934
"""Fetch access token in one step.
3035
@@ -36,7 +41,7 @@ def authorize_access_token(self, request, **kwargs):
3641
if not state:
3742
raise OAuthError(description='Missing "oauth_token" parameter')
3843

39-
data = self.framework.get_state_data(request, state)
44+
data = self.framework.get_state_data(request.session, state)
4045
if not data:
4146
raise OAuthError(description='Missing "request_token" in temporary data')
4247

@@ -46,11 +51,13 @@ def authorize_access_token(self, request, **kwargs):
4651
params['redirect_uri'] = redirect_uri
4752

4853
params.update(kwargs)
49-
self.framework.clear_state_data(request, state)
54+
self.framework.clear_state_data(request.session, state)
5055
return self.fetch_access_token(**params)
5156

5257

53-
class DjangoOAuth2App(DjangoAppMixin, OAuth2App):
58+
class DjangoOAuth2App(DjangoAppMixin, OAuth2Mixin, OpenIDMixin, BaseApp):
59+
client_cls = OAuth2Session
60+
5461
def authorize_access_token(self, request, **kwargs):
5562
"""Fetch access token in one step.
5663
@@ -72,19 +79,9 @@ def authorize_access_token(self, request, **kwargs):
7279
'state': request.POST.get('state'),
7380
}
7481

75-
data = self.framework.get_state_data(request, params.get('state'))
76-
if data is None:
77-
raise MismatchingStateError()
78-
79-
code_verifier = data.get('code_verifier')
80-
if code_verifier:
81-
params['code_verifier'] = code_verifier
82-
83-
redirect_uri = data.get('redirect_uri')
84-
if redirect_uri:
85-
params['redirect_uri'] = redirect_uri
86-
params.update(kwargs)
87-
token = self.fetch_access_token(**params)
82+
state_data = self.framework.get_state_data(request.session, params.get('state'))
83+
params = self._format_state_params(state_data, params)
84+
token = self.fetch_access_token(**params, **kwargs)
8885

8986
if 'id_token' in token and 'nonce' in params:
9087
userinfo = self.parse_id_token(token, nonce=params['nonce'])
Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,51 @@
1-
# flake8: noqa
1+
from werkzeug.local import LocalProxy
2+
from .integration import FlaskIntegration, token_update
3+
from .apps import FlaskOAuth1App, FlaskOAuth2App
4+
from ..base_client import BaseOAuth, OAuthError
5+
6+
7+
class OAuth(BaseOAuth):
8+
oauth1_client_cls = FlaskOAuth1App
9+
oauth2_client_cls = FlaskOAuth2App
10+
framework_integration_cls = FlaskIntegration
11+
12+
def __init__(self, app=None, cache=None, fetch_token=None, update_token=None):
13+
super(OAuth, self).__init__(
14+
cache=cache, fetch_token=fetch_token, update_token=update_token)
15+
self.app = app
16+
if app:
17+
self.init_app(app)
18+
19+
def init_app(self, app, cache=None, fetch_token=None, update_token=None):
20+
"""Initialize lazy for Flask app. This is usually used for Flask application
21+
factory pattern.
22+
"""
23+
self.app = app
24+
if cache is not None:
25+
self.cache = cache
26+
27+
if fetch_token:
28+
self.fetch_token = fetch_token
29+
if update_token:
30+
self.update_token = update_token
31+
32+
app.extensions = getattr(app, 'extensions', {})
33+
app.extensions['authlib.integrations.flask_client'] = self
34+
35+
def create_client(self, name):
36+
if not self.app:
37+
raise RuntimeError('OAuth is not init with Flask app.')
38+
return super(OAuth, self).create_client(name)
39+
40+
def register(self, name, overwrite=False, **kwargs):
41+
self._registry[name] = (overwrite, kwargs)
42+
if self.app:
43+
return self.create_client(name)
44+
return LocalProxy(lambda: self.create_client(name))
245

3-
from .oauth_registry import OAuth
4-
from .remote_app import FlaskRemoteApp
5-
from .integration import token_update, FlaskIntegration
6-
from ..base_client import OAuthError
746

847
__all__ = [
9-
'OAuth', 'FlaskRemoteApp', 'FlaskIntegration',
48+
'OAuth', 'FlaskIntegration',
49+
'FlaskOAuth1App', 'FlaskOAuth2App',
1050
'token_update', 'OAuthError',
1151
]
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from flask import redirect, request, session
2+
from ..base_client import OAuthError, MismatchingStateError
3+
from ..requests_client.apps import OAuth1App, OAuth2App
4+
5+
6+
class FlaskAppMixin(object):
7+
def save_authorize_data(self, **kwargs):
8+
state = kwargs.pop('state', None)
9+
if state:
10+
self.framework.set_state_data(session, state, kwargs)
11+
else:
12+
raise RuntimeError('Missing state value')
13+
14+
def authorize_redirect(self, redirect_uri=None, **kwargs):
15+
"""Create a HTTP Redirect for Authorization Endpoint.
16+
17+
:param redirect_uri: Callback or redirect URI for authorization.
18+
:param kwargs: Extra parameters to include.
19+
:return: A HTTP redirect response.
20+
"""
21+
rv = self.create_authorization_url(redirect_uri, **kwargs)
22+
self.save_authorize_data(redirect_uri=redirect_uri, **rv)
23+
return redirect(rv['url'])
24+
25+
26+
class FlaskOAuth1App(FlaskAppMixin, OAuth1App):
27+
def authorize_access_token(self, **kwargs):
28+
"""Fetch access token in one step.
29+
30+
:return: A token dict.
31+
"""
32+
params = request.args.to_dict(flat=True)
33+
state = params.get('oauth_token')
34+
if not state:
35+
raise OAuthError(description='Missing "oauth_token" parameter')
36+
37+
data = self.framework.get_state_data(session, state)
38+
if not data:
39+
raise OAuthError(description='Missing "request_token" in temporary data')
40+
41+
params['request_token'] = data['request_token']
42+
redirect_uri = data.get('redirect_uri')
43+
if redirect_uri:
44+
params['redirect_uri'] = redirect_uri
45+
46+
params.update(kwargs)
47+
self.framework.clear_state_data(session, state)
48+
return self.fetch_access_token(**params)
49+
50+
51+
class FlaskOAuth2App(FlaskAppMixin, OAuth2App):
52+
def authorize_access_token(self, **kwargs):
53+
"""Fetch access token in one step.
54+
55+
:return: A token dict.
56+
"""
57+
if request.method == 'GET':
58+
error = request.args.get('error')
59+
if error:
60+
description = request.args.get('error_description')
61+
raise OAuthError(error=error, description=description)
62+
63+
params = {
64+
'code': request.args['code'],
65+
'state': request.args.get('state'),
66+
}
67+
else:
68+
params = {
69+
'code': request.form['code'],
70+
'state': request.form.get('state'),
71+
}
72+
73+
data = self.framework.get_state_data(session, params.get('state'))
74+
75+
if data is None:
76+
raise MismatchingStateError()
77+
78+
code_verifier = data.get('code_verifier')
79+
if code_verifier:
80+
params['code_verifier'] = code_verifier
81+
82+
redirect_uri = data.get('redirect_uri')
83+
if redirect_uri:
84+
params['redirect_uri'] = redirect_uri
85+
86+
params.update(kwargs)
87+
token = self.fetch_access_token(**params)
88+
89+
if 'id_token' in token and 'nonce' in params:
90+
userinfo = self.parse_id_token(token, nonce=params['nonce'])
91+
token['userinfo'] = userinfo
92+
return token

0 commit comments

Comments
 (0)