# copyright 2022 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact https://www.logilab.fr -- mailto:contact@logilab.fr
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import logging
from contextlib import contextmanager
from enum import Enum
from functools import wraps
from typing import Callable

from cubicweb import AuthenticationError, QueryError, Unauthorized, Forbidden
from cubicweb.pyramid.core import Connection
from cubicweb.schema_exporters import JSONSchemaExporter
from pyramid.config import Configurator
from pyramid.httpexceptions import HTTPError
from pyramid.request import Request
from pyramid.response import Response
from pyramid.view import view_config, view_defaults
from rql import RQLException
from yams import ValidationError, UnknownType

from cubicweb_api.api_transaction import ApiTransactionsRepository
from cubicweb_api.auth.jwt_auth import setup_jwt
from cubicweb_api.constants import (
    API_ROUTE_NAME_PREFIX,
)
from cubicweb_api.httperrors import get_http_error, get_http_500_error
from cubicweb_api.openapi.openapi import setup_openapi
from cubicweb_api.util import get_cw_repo, get_transactions_repository

log = logging.getLogger(__name__)


class ApiRoutes(Enum):
    """
    All the available routes as listed in the openapi/openapi_template.yml file.

    Each member's value is the base of a route name: get_route_name() prepends
    API_ROUTE_NAME_PREFIX to it to build the name actually registered.
    """

    # Schema export and one-shot RQL execution
    schema = "schema"
    rql = "rql"
    # Authentication and information about the connected user
    login = "login"
    current_user = "current_user"
    # Transaction life cycle
    transaction_begin = "transaction/begin"
    transaction_execute = "transaction/execute"
    transaction_commit = "transaction/commit"
    transaction_rollback = "transaction/rollback"
    # NOTE(review): no view in this module is registered for this route --
    # presumably it is handled elsewhere; confirm.
    help = "help"


def get_route_name(route_name: ApiRoutes) -> str:
    """
    Build the namespaced name of an API route.

    The api prefix is put in front of the route's value so that the resulting
    name cannot clash with routes declared by other cubes.

    :param route_name: The route whose value is used as the name base
    :return: The api prefix immediately followed by the route's value
    """
    return "{}{}".format(API_ROUTE_NAME_PREFIX, route_name.value)


@contextmanager
def _catch_rql_errors(cnx: Connection):
    """
    Context manager converting RQL related errors raised in its body
    into an HTTP 400 error.

    RQLException, QueryError, ValidationError and UnknownType are caught,
    logged and replaced by an HTTP error built from the exception class name
    and message. A ValidationError is translated with ``cnx._`` beforehand.
    Any other exception propagates untouched.

    :param cnx: The connection whose ``_`` attribute is used to translate
        validation errors
    :raise HTTPError: with 400 code if a RQL related exception is caught
    """
    try:
        try:
            yield
        except ValidationError as e:
            # Translate the message first, then let the outer handler
            # turn the (now translated) error into an HTTP 400 error.
            e.translate(cnx._)
            raise
    except (RQLException, QueryError, ValidationError, UnknownType) as e:
        log.info(e.__class__.__name__, exc_info=True)
        # Chain explicitly so the original error stays attached as the cause.
        raise get_http_error(400, e.__class__.__name__, str(e)) from e


def view_exception_handler(func: Callable):
    """
    Use it as a decorator for any pyramid view to catch authentication errors
    and turn them into HTTP 401 or 403 errors.
    It also catches any other leftover exceptions and raises an HTTP 500 error.

    Exception handling, in order:

    - ``HTTPError``: returned as is;
    - ``AuthenticationError`` / ``Unauthorized``: an HTTP 401 error is returned;
    - ``Forbidden``: an HTTP 403 error is returned;
    - anything else: logged with its traceback, then an HTTP 500 error is raised.

    :param func: The pyramid view function
    :return: The wrapping function, carrying the name and docstring of ``func``
    :raise HTTPError: with 500 code if an unexpected exception is caught
    """

    # Keep the metadata of the decorated view (name, docstring, ...).
    @wraps(func)
    def request_wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except HTTPError as e:
            # The decorated function raised its own HTTP error, simply forward it.
            return e
        except (AuthenticationError, Unauthorized) as e:
            # User was not authenticated, return 401 HTTP error
            log.info(e.__class__.__name__, exc_info=True)
            return get_http_error(401, e.__class__.__name__, str(e))
        except Forbidden as e:
            # User was authenticated but had insufficient privileges, return 403 HTTP error
            log.info(e.__class__.__name__, exc_info=True)
            return get_http_error(403, e.__class__.__name__, str(e))
        except Exception:
            # An exception was raised but not caught, this is a server error (HTTP 500).
            # Logged at error level (with traceback) so it is not filtered out
            # together with routine client errors.
            log.exception("ServerError")
            raise get_http_500_error()

    return request_wrapper


def authorized_users_only(func: Callable):
    """
    Decorator restricting a view method to identified users.

    The decorated view is called when the request carries an authenticated
    user id, or when the repository configuration defines an "anonymous-user".
    In any other case an AuthenticationError is raised.

    :param func: The pyramid view function
    :raise AuthenticationError: if there is neither an authenticated user nor
        a configured anonymous user
    """

    @wraps(func)
    def wrapper(self):
        request = self.request
        if request.authenticated_userid is None:
            # Nobody is logged in: the repository configuration is only
            # consulted in that case, to see if anonymous access is enabled.
            anonymous_user = get_cw_repo(request).config["anonymous-user"]
            if anonymous_user is None:
                raise AuthenticationError
        return func(self)

    return wrapper


@view_defaults(
    request_method="POST",
    renderer="cubicweb_api_json",
    require_csrf=False,
    openapi=True,
)
class ApiViews:
    """
    Pyramid views implementing the API routes.

    Unless a view overrides it, every route answers POST requests, is rendered
    with the "cubicweb_api_json" renderer, has ``require_csrf`` disabled and is
    flagged ``openapi=True``.
    See the openapi/openapi_template.yml file for the description of each route.
    """

    def __init__(self, request: Request):
        self.request = request

    @view_exception_handler
    @authorized_users_only
    @view_config(
        route_name=get_route_name(ApiRoutes.schema),
        request_method="GET",
    )
    def schema_route(self):
        """
        Return the repository schema exported as a dictionary.

        See the openapi/openapi_template.yml file for more information on this route.
        """
        cw_repo = get_cw_repo(self.request)
        return JSONSchemaExporter().export_as_dict(cw_repo.schema)

    @view_exception_handler
    @authorized_users_only
    @view_config(route_name=get_route_name(ApiRoutes.rql))
    def rql_route(self):
        """
        Execute the query found in the validated request body and return its rows.

        See the openapi/openapi_template.yml file for more information on this route.
        """
        request = self.request
        body = request.openapi_validated.body
        rql_query = body["query"]
        rql_args = body["params"]
        with _catch_rql_errors(request.cw_cnx):
            result_set = request.cw_cnx.execute(rql_query, rql_args)
            return result_set.rows

    @view_exception_handler
    @view_config(route_name=get_route_name(ApiRoutes.login))
    def login_route(self):
        """
        Authenticate with the login and password found in the validated request
        body; on success answer with a 204 response carrying the headers
        produced by the authentication policy.

        See the openapi/openapi_template.yml file for more information on this route.
        """
        body = self.request.openapi_validated.body
        user_login = body["login"]
        user_password = body["password"]

        cw_repo = get_cw_repo(self.request)
        with cw_repo.internal_cnx() as internal_cnx:
            try:
                authenticated = cw_repo.authenticate_user(
                    internal_cnx, user_login, password=user_password
                )
            except AuthenticationError:
                raise get_http_error(
                    401, "AuthenticationFailure", "Login and/or password invalid."
                )
            # Only reached when authentication succeeded.
            remember_headers = self.request.authentication_policy.remember(
                self.request,
                authenticated.eid,
                login=authenticated.login,
                firstname=authenticated.firstname,
                lastname=authenticated.surname,
            )
            return Response(headers=remember_headers, status=204)

    @view_exception_handler
    @authorized_users_only
    @view_config(
        route_name=get_route_name(ApiRoutes.current_user),
        request_method="GET",
    )
    def current_user(self):
        """
        Return the eid, login and title of the user of the request's connection.

        See the openapi/openapi_template.yml file for more information on this route.
        """
        cw_user = self.request.cw_cnx.user
        return {
            "eid": cw_user.eid,
            "login": cw_user.login,
            "dcTitle": cw_user.dc_title(),
        }

    @view_exception_handler
    @authorized_users_only
    @view_config(route_name=get_route_name(ApiRoutes.transaction_begin))
    def transaction_begin_route(self):
        """
        Begin a transaction for the user of the request's connection and return
        what the transactions repository gives back.

        See the openapi/openapi_template.yml file for more information on this route.
        """
        transactions_repo = get_transactions_repository(self.request)
        cw_user = self.request.cw_cnx.user
        return transactions_repo.begin_transaction(cw_user)

    @view_exception_handler
    @authorized_users_only
    @view_config(route_name=get_route_name(ApiRoutes.transaction_execute))
    def transaction_execute_route(self):
        """
        Execute a query inside the transaction identified by the "uuid" of the
        validated request body and return the resulting rows.

        See the openapi/openapi_template.yml file for more information on this route.
        """
        transactions_repo = get_transactions_repository(self.request)
        body = self.request.openapi_validated.body
        transaction_id = body["uuid"]
        rql_query = body["query"]
        rql_args = body["params"]
        with _catch_rql_errors(self.request.cw_cnx):
            transaction = transactions_repo[transaction_id]
            return transaction.execute(rql_query, rql_args).rows

    @view_exception_handler
    @authorized_users_only
    @view_config(route_name=get_route_name(ApiRoutes.transaction_commit))
    def transaction_commit_route(self):
        """
        Commit the transaction identified by the "uuid" of the validated request
        body and return the result of the commit.

        See the openapi/openapi_template.yml file for more information on this route.
        """
        transactions_repo = get_transactions_repository(self.request)
        transaction_id = self.request.openapi_validated.body["uuid"]
        with _catch_rql_errors(self.request.cw_cnx):
            transaction = transactions_repo[transaction_id]
            return transaction.commit()

    @view_exception_handler
    @authorized_users_only
    @view_config(route_name=get_route_name(ApiRoutes.transaction_rollback))
    def transaction_rollback_route(self):
        """
        Roll back the transaction identified by the "uuid" of the validated
        request body, end it, and return the result of the rollback.

        See the openapi/openapi_template.yml file for more information on this route.
        """
        transactions_repo = get_transactions_repository(self.request)
        transaction_id = self.request.openapi_validated.body["uuid"]
        outcome = transactions_repo[transaction_id].rollback()
        # The transaction is ended only once the rollback has been performed.
        transactions_repo.end_transaction(transaction_id)
        return outcome


def includeme(config: Configurator):
    """
    Pyramid inclusion hook for the API.

    In order: sets up JWT, stores a new ApiTransactionsRepository on the
    CubicWeb repository as ``api_transactions``, sets up openapi, registers
    the routes through pyramid_openapi3 and finally scans for view
    declarations.

    :param config: The pyramid configurator being set up
    """
    setup_jwt(config)

    cw_repo = get_cw_repo(config)
    transactions_repo = ApiTransactionsRepository(cw_repo)
    cw_repo.api_transactions = transactions_repo

    setup_openapi(config)
    config.pyramid_openapi3_register_routes()
    config.scan()