Skip to content
Snippets Groups Projects
Commit d9a956e51cc4 authored by Arnaud Vergnet's avatar Arnaud Vergnet :sun_with_face:
Browse files

feat(jwt): use pyramid_jwt as local module to alter its behavior

This allows us to change how cookies are created
parent 794079ce3b96
No related branches found
No related tags found
1 merge request!6feat: add jwt authentication
...@@ -14,10 +14,7 @@ ...@@ -14,10 +14,7 @@
description = "This cube is the new api which will be integrated in CubicWeb 4." description = "This cube is the new api which will be integrated in CubicWeb 4."
web = "https://forge.extranet.logilab.fr/cubicweb/cubes/api" web = "https://forge.extranet.logilab.fr/cubicweb/cubes/api"
__depends__ = { __depends__ = {"cubicweb": ">= 3.36.0", "PyJWT": ">= 2.4.0"}
"cubicweb": ">= 3.36.0",
"pyramid-jwt": None,
}
__recommends__ = {} __recommends__ = {}
classifiers = [ classifiers = [
......
import logging import logging
import jwt import jwt
from pyramid_jwt import JWTCookieAuthenticationPolicy, create_jwt_authentication_policy from cubicweb_api.jwt_policy import (
JWTCookieAuthenticationPolicy,
JWTAuthenticationPolicy,
)
from pyramid.config import Configurator from pyramid.config import Configurator
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -4,9 +7,9 @@ ...@@ -4,9 +7,9 @@
from pyramid.config import Configurator from pyramid.config import Configurator
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def create_cubicweb_jwt_cookie_policy(config: Configurator, prefix="cubicweb.auth.jwt"): def create_jwt_policy(config: Configurator, prefix="cubicweb.auth.jwt"):
cfg = config.registry.settings cfg = config.registry.settings
private_key_string = prefix + ".private_key" private_key_string = prefix + ".private_key"
if private_key_string not in cfg: if private_key_string not in cfg:
...@@ -20,8 +23,12 @@ ...@@ -20,8 +23,12 @@
"http_header", "http_header",
"auth_type", "auth_type",
) )
kwargs = {k: cfg.get("{}.{}".format(prefix, k), None) for k in keys} kwargs = {}
auth_policy = create_jwt_authentication_policy(config, **kwargs) for k in keys:
key_path = "{}.{}".format(prefix, k)
if key_path in cfg:
kwargs[k] = cfg.get(key_path)
auth_policy = JWTAuthenticationPolicy(**kwargs)
cookie_policy = JWTCookieAuthenticationPolicy.make_from( cookie_policy = JWTCookieAuthenticationPolicy.make_from(
auth_policy, cookie_name="CW_JWT", https_only=True, reissue_time=7200 auth_policy, cookie_name="CW_JWT", https_only=True, reissue_time=7200
) )
...@@ -48,7 +55,7 @@ ...@@ -48,7 +55,7 @@
def setup_jwt(config: Configurator): def setup_jwt(config: Configurator):
config.include("pyramid_jwt") config.include("pyramid_jwt")
try: try:
policy = create_cubicweb_jwt_cookie_policy(config) policy = create_jwt_policy(config)
except KeyError as e: except KeyError as e:
log.warning( log.warning(
"Could not configure JWT policy: missing configuration key %s", str(e) "Could not configure JWT policy: missing configuration key %s", str(e)
......
import datetime
import logging
import time
import warnings
from json import JSONEncoder
import jwt
from pyramid.renderers import JSON
from webob.cookies import CookieProfile
from zope.interface import implementer
from pyramid.authentication import CallbackAuthenticationPolicy
from pyramid.interfaces import IAuthenticationPolicy, IRendererFactory
from pyramid.request import Request
log = logging.getLogger(__name__)
marker = []
# Adapted from https://github.com/wichert/pyramid_jwt
# Cookie creation was rewritten because the previous method would base64 encode the JWT
class PyramidJSONEncoderFactory(JSON):
def __init__(self, pyramid_registry=None, **kw):
super().__init__(**kw)
self.registry = pyramid_registry
def __call__(self, *args, **kwargs):
json_renderer = None
if self.registry is not None:
json_renderer = self.registry.queryUtility(
IRendererFactory, "json", default=JSONEncoder
)
request = kwargs.get("request")
if not kwargs.get("default") and isinstance(json_renderer, JSON):
self.components = json_renderer.components
kwargs["default"] = self._make_default(request)
return JSONEncoder(*args, **kwargs)
json_encoder_factory = PyramidJSONEncoderFactory(None)
@implementer(IAuthenticationPolicy)
class JWTAuthenticationPolicy(CallbackAuthenticationPolicy):
def __init__(
self,
private_key,
public_key=None,
algorithm="HS512",
leeway=0,
expiration=None,
default_claims=None,
http_header="Authorization",
auth_type="JWT",
callback=None,
json_encoder=None,
audience=None,
):
self.private_key = private_key
self.public_key = public_key if public_key is not None else private_key
self.algorithm = algorithm
self.leeway = leeway
self.default_claims = default_claims if default_claims else {}
self.http_header = http_header
self.auth_type = auth_type
if expiration:
if not isinstance(expiration, datetime.timedelta):
expiration = datetime.timedelta(seconds=expiration)
self.expiration = expiration
else:
self.expiration = None
if audience:
self.audience = audience
else:
self.audience = None
self.callback = callback
if json_encoder is None:
json_encoder = json_encoder_factory
self.json_encoder = json_encoder
self.jwt_std_claims = ("sub", "iat", "exp", "aud")
def create_token(self, principal, expiration=None, audience=None, **claims):
payload = self.default_claims.copy()
payload.update(claims)
payload["sub"] = principal
payload["iat"] = iat = datetime.datetime.utcnow()
expiration = expiration or self.expiration
audience = audience or self.audience
if expiration:
if not isinstance(expiration, datetime.timedelta):
expiration = datetime.timedelta(seconds=expiration)
payload["exp"] = iat + expiration
if audience:
payload["aud"] = audience
token = jwt.encode(
payload,
self.private_key,
algorithm=self.algorithm,
json_encoder=self.json_encoder,
)
if not isinstance(token, str): # Python3 unicode madness
token = token.decode("ascii")
return token
def get_claims(self, request: Request):
if self.http_header == "Authorization":
try:
if request.authorization is None:
return {}
except ValueError: # Invalid Authorization header
return {}
(auth_type, token) = request.authorization
if auth_type != self.auth_type:
return {}
else:
token = request.headers.get(self.http_header)
if not token:
return {}
return self.jwt_decode(request, token)
def jwt_decode(self, request: Request, token: str):
try:
claims = jwt.decode(
token,
self.public_key,
algorithms=[self.algorithm],
leeway=self.leeway,
audience=self.audience,
)
return claims
except jwt.InvalidTokenError as e:
log.warning("Invalid JWT token from %s: %s", request.remote_addr, e)
return {}
def unauthenticated_userid(self, request: Request):
return request.jwt_claims.get("sub")
def remember(self, request: Request, principal, **kw):
warnings.warn(
"JWT tokens need to be returned by an API. Using remember() "
"has no effect.",
stacklevel=3,
)
return []
def forget(self, request: Request):
warnings.warn(
"JWT tokens are managed by API (users) manually. Using forget() "
"has no effect.",
stacklevel=3,
)
return []
class ReissueError(Exception):
pass
@implementer(IAuthenticationPolicy)
class JWTCookieAuthenticationPolicy(JWTAuthenticationPolicy):
def __init__(
self,
private_key,
public_key=None,
algorithm="HS512",
leeway=0,
expiration=None,
default_claims=None,
http_header="Authorization",
auth_type="JWT",
callback=None,
json_encoder=None,
audience=None,
cookie_name=None,
https_only=True,
reissue_time=None,
cookie_path=None,
):
super(JWTCookieAuthenticationPolicy, self).__init__(
private_key,
public_key,
algorithm,
leeway,
expiration,
default_claims,
http_header,
auth_type,
callback,
json_encoder,
audience,
)
self.https_only = https_only
self.cookie_name = cookie_name or "Authorization"
self.max_age = self.expiration and self.expiration.total_seconds()
self.cookie_path = cookie_path
if reissue_time and isinstance(reissue_time, datetime.timedelta):
reissue_time = reissue_time.total_seconds()
self.reissue_time = reissue_time
self.cookie_profile = CookieProfile(
cookie_name=self.cookie_name,
secure=self.https_only,
max_age=self.max_age,
httponly=True,
path=cookie_path,
)
@staticmethod
def make_from(policy, **kwargs):
if not isinstance(policy, JWTAuthenticationPolicy):
pol_type = policy.__class__.__name__
raise ValueError("Invalid policy type %s" % pol_type)
return JWTCookieAuthenticationPolicy(
private_key=policy.private_key,
public_key=policy.public_key,
algorithm=policy.algorithm,
leeway=policy.leeway,
expiration=policy.expiration,
default_claims=policy.default_claims,
http_header=policy.http_header,
auth_type=policy.auth_type,
callback=policy.callback,
json_encoder=policy.json_encoder,
audience=policy.audience,
**kwargs
)
def remember(self, request: Request, principal, **kw):
token = self.create_token(principal, self.expiration, self.audience, **kw)
if hasattr(request, "_jwt_cookie_reissued"):
request._jwt_cookie_reissue_revoked = True
request.response.set_cookie(
self.cookie_name,
token,
secure=self.https_only,
max_age=self.max_age,
httponly=True,
path=self.cookie_path,
domain=request.domain,
)
def forget(self, request: Request):
request._jwt_cookie_reissue_revoked = True
request.response.set_cookie(self.cookie_name, None)
def get_claims(self, request: Request):
# FIXME does not work. Store the token instead of using the cookie profile
profile = self.cookie_profile.bind(request)
cookie = profile.get_value()
reissue = self.reissue_time is not None
if cookie is None:
return {}
claims = self.jwt_decode(request, cookie)
if reissue and not hasattr(request, "_jwt_cookie_reissued"):
self._handle_reissue(request, claims)
return claims
def _handle_reissue(self, request: Request, claims: dict):
if not request or not claims:
raise ValueError("Cannot handle JWT reissue: insufficient arguments")
if "iat" not in claims:
raise ReissueError("Token claim's is missing IAT")
if "sub" not in claims:
raise ReissueError("Token claim's is missing SUB")
token_dt = claims["iat"]
principal = claims["sub"]
now = time.time()
if now < token_dt + self.reissue_time:
# Token not yet eligible for reissuing
return
extra_claims = dict(
filter(lambda item: item[0] not in self.jwt_std_claims, claims.items())
)
self.remember(request, principal, **extra_claims)
def reissue_jwt_cookie(re_request: Request):
if not hasattr(re_request, "_jwt_cookie_reissue_revoked"):
self.remember(re_request, principal, **extra_claims)
request.add_response_callback(reissue_jwt_cookie)
request._jwt_cookie_reissued = True
...@@ -60,14 +60,14 @@ ...@@ -60,14 +60,14 @@
with repo.internal_cnx() as cnx: with repo.internal_cnx() as cnx:
cwuser = repo.authenticate_user(cnx, login, password=pwd) cwuser = repo.authenticate_user(cnx, login, password=pwd)
return { self.request.authentication_policy.remember(
"token": self.request.create_jwt_token( self.request,
cwuser.eid, cwuser.eid,
login=cwuser.login, login=cwuser.login,
firstname=cwuser.firstname, firstname=cwuser.firstname,
lastname=cwuser.surname, lastname=cwuser.surname,
) )
} self.request.response.status_code = 204
class TransactionContext(Context): class TransactionContext(Context):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment