import logging
import jwt
from cubicweb_api.jwt_policy import (
    JWTCookieAuthenticationPolicy,
    JWTAuthenticationPolicy,
)
from pyramid.config import Configurator

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


def create_jwt_policy(
    config: Configurator,
    prefix="cubicweb.auth.jwt",
    *,
    cookie_name="CW_JWT",
    https_only=True,
    reissue_time=7200,
):
    """Build the JWT cookie authentication policy from the registry settings.

    For each of ``private_key``, ``public_key``, ``algorithm``, ``expiration``,
    ``leeway``, ``http_header`` and ``auth_type``, the setting
    ``<prefix>.<name>`` is read from ``config.registry.settings``. Only the
    settings that are present are forwarded to ``JWTAuthenticationPolicy``,
    so absent ones keep that class's own defaults.

    :param config: Pyramid configurator whose ``registry.settings`` are read.
    :param prefix: dotted prefix under which the JWT settings live.
    :param cookie_name: forwarded to ``JWTCookieAuthenticationPolicy.make_from``
        as the name of the token cookie.
    :param https_only: forwarded to ``make_from`` unchanged.
    :param reissue_time: forwarded to ``make_from`` unchanged (presumably in
        seconds, as in pyramid_jwt -- confirm against ``cubicweb_api``).
    :returns: the cookie policy derived from the header-based policy.
    :raises KeyError: if ``<prefix>.private_key`` is not configured.
    :raises ValueError: if ``<prefix>.expiration`` or ``<prefix>.leeway`` is a
        non-blank string that is not an integer.
    """
    cfg = config.registry.settings
    private_key_string = prefix + ".private_key"
    if private_key_string not in cfg:
        # The signing key is mandatory; callers treat KeyError as
        # "JWT not configured".
        raise KeyError(private_key_string)
    keys = (
        "private_key",
        "public_key",
        "algorithm",
        "expiration",
        "leeway",
        "http_header",
        "auth_type",
    )
    # Settings loaded from an .ini file are plain strings, but these two must
    # be numbers: pyramid_jwt builds ``timedelta(seconds=expiration)`` and
    # PyJWT adds ``leeway`` to timestamps.
    integer_keys = ("expiration", "leeway")
    kwargs = {}
    for k in keys:
        key_path = "{}.{}".format(prefix, k)
        if key_path not in cfg:
            continue
        value = cfg[key_path]
        if k in integer_keys and isinstance(value, str):
            value = value.strip()
            if not value:
                # Blank value in the config file: keep the policy default.
                continue
            try:
                value = int(value)
            except ValueError as err:
                raise ValueError(
                    "{} must be an integer, got {!r}".format(key_path, value)
                ) from err
        kwargs[k] = value
    auth_policy = JWTAuthenticationPolicy(**kwargs)
    cookie_policy = JWTCookieAuthenticationPolicy.make_from(
        auth_policy,
        cookie_name=cookie_name,
        https_only=https_only,
        reissue_time=reissue_time,
    )
    return cookie_policy


def _request_create_token(request, principal, expiration=None, audience=None, **claims):
    return request.authentication_policy.create_token(
        principal, expiration, audience, **claims
    )


def _request_claims(request):
    """Return the decoded claims of the JWT stored in the request's cookie.

    The token is read from the cookie named by the policy's ``cookie_name``.
    It is verified with the policy's ``algorithm`` and verification key, and
    against the policy's ``audience`` when it has one.

    This is best-effort: it never raises.

    :param request: Pyramid request exposing ``authentication_policy``.
    :returns: the claims dict, or ``{}`` when there is no cookie, the token is
        invalid (expired, tampered, malformed, wrong audience), or decoding
        fails unexpectedly (that last case is logged).
    """
    try:
        policy = request.authentication_policy
        token = request.cookies.get(policy.cookie_name)
        if token is None:
            # No JWT cookie on this request: nothing to decode.
            return {}
        # Verification must use the *public* key: with asymmetric algorithms
        # (RS*/ES*/PS*) the private key cannot verify a signature. Fall back
        # to the private key when no public key is exposed (symmetric HS*).
        verify_key = getattr(policy, "public_key", None) or policy.private_key
        return jwt.decode(
            token,
            verify_key,
            algorithms=[policy.algorithm],
            audience=getattr(policy, "audience", None),
        )
    except jwt.InvalidTokenError:
        # Expected for expired/forged/garbled tokens: treat as anonymous.
        return {}
    except Exception:
        # Anything else points at a misconfiguration; keep best-effort
        # semantics but make it visible.
        log.warning("Unexpected error while decoding the JWT cookie", exc_info=True)
        return {}


def setup_jwt(config: Configurator):
    """Install the JWT cookie authentication policy on *config*.

    Steps:

    1. Include ``pyramid_jwt``.
    2. Build the policy with ``create_jwt_policy``. If a required setting is
       missing (KeyError), log a warning and stop.
    3. Otherwise, append the policy to the private ``_policies`` list of the
       ``cubicweb.authpolicy`` registry entry.
    4. Register the request methods ``create_jwt_token``, ``jwt_claims``
       (reified) and ``authentication_policy`` (reified).
    """
    config.include("pyramid_jwt")
    try:
        jwt_policy = create_jwt_policy(config)
    except KeyError as missing:
        log.warning(
            "Could not configure JWT policy: missing configuration key %s",
            str(missing),
        )
        return

    # NOTE: reaches into a private attribute of the registered auth policy.
    config.registry["cubicweb.authpolicy"]._policies.append(jwt_policy)
    config.add_request_method(_request_create_token, "create_jwt_token")
    config.add_request_method(_request_claims, "jwt_claims", reify=True)

    def _current_policy(request):
        # Every request sees the same policy object built above.
        return jwt_policy

    config.add_request_method(
        _current_policy, "authentication_policy", reify=True
    )
    log.info("JWT policy configured")