Newer
Older
# -*- coding: utf-8 -*-
# copyright 2022 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact https://www.logilab.fr -- mailto:contact@logilab.fr
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import logging
from enum import Enum
from functools import wraps
from typing import Union

from cubicweb import AuthenticationError, QueryError
from cubicweb.pyramid.core import CubicWebPyramidRequest
from cubicweb.schema_exporters import JSONSchemaExporter
from cubicweb.server.repository import Repository
from pyramid.config import Configurator
from pyramid.httpexceptions import HTTPError
from pyramid.request import Request
from pyramid.view import view_config
from rql import RQLException
from yams import ValidationError, UnknownType

from cubicweb_api.api_transaction import ApiTransactionsRepository
from cubicweb_api.constants import (
    DEFAULT_ROUTE_PARAMS,
    API_PATTERN_PREFIX,
)
from cubicweb_api.httperrors import get_http_error, get_http_500_error
from cubicweb_api.jwt_auth import setup_jwt
from cubicweb_api.openapi.openapi import register_openapi_routes
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
class ApiRoutes(str, Enum):
    """
    Names of the routes exposed by the API, without the API prefix.

    Members are also ``str`` instances, so each one compares equal to its value.
    """

    # Schema export and plain RQL execution
    schema = "schema"
    rql = "rql"

    # Authentication
    login = "login"

    # Transaction handling
    transaction_begin = "transaction/begin"
    transaction_execute = "transaction/execute"
    transaction_commit = "transaction/commit"
    transaction_rollback = "transaction/rollback"

    help = "help"
def get_cw_request(request: Request) -> CubicWebPyramidRequest:
    """
    Give access to the CubicWeb request attached to a pyramid request.

    :param request: The pyramid request
    :return: The object stored in the ``cw_request`` attribute of the request
    """
    cw_request = request.cw_request
    return cw_request
def get_cw_repo(req_or_conf: Union[Request, Configurator]) -> Repository:
    """
    Look up the CubicWeb repository in the pyramid registry.

    :param req_or_conf: A pyramid request or configurator, anything exposing ``registry``
    :return: The object registered under the ``"cubicweb.repository"`` key
    """
    registry = req_or_conf.registry
    return registry["cubicweb.repository"]
def cw_view_config(route_name: str, **kwargs):
    """
    Build a pyramid ``view_config`` decorator for one of the API routes.

    The route name is prefixed with ``API_PATTERN_PREFIX`` and the keyword arguments
    are layered on top of ``DEFAULT_ROUTE_PARAMS``, so a caller can override any default.

    :param route_name: The route name without prefix, either a plain string or an
        enum member (such as an ``ApiRoutes`` entry) whose ``value`` is that string
    :param kwargs: Extra ``view_config`` settings, taking precedence over the defaults
    :return: The configured ``view_config`` decorator
    """
    # On Python 3.11 and later, formatting a ``(str, Enum)`` member yields
    # "ClassName.member" rather than its value, so unwrap enum members before
    # building the route name. Plain strings have no ``value`` and pass through.
    raw_name = getattr(route_name, "value", route_name)
    settings = dict(DEFAULT_ROUTE_PARAMS, **kwargs)
    return view_config(route_name=f"{API_PATTERN_PREFIX}{raw_name}", **settings)
def view_exception_handler(func):
    """
    Use it as a decorator for any pyramid view to turn exceptions into HTTP errors.

    - an ``HTTPError`` raised by the view is returned unchanged
    - an ``AuthenticationError`` is logged and returned as an HTTP 401 error
    - any other leftover exception is logged and an HTTP 500 error is raised

    :param func: The pyramid view function, called with the request as only argument
    :return: The wrapped view function
    """

    # wraps() keeps the name, docstring and attributes of the original view
    # on the wrapper.
    @wraps(func)
    def request_wrapper(request: Request):
        try:
            return func(request)
        except HTTPError as e:
            return e
        except AuthenticationError as e:
            log.info(e.__class__.__name__, exc_info=True)
            return get_http_error(401, e.__class__.__name__, str(e))
        except Exception:
            log.info("ServerError", exc_info=True)
            raise get_http_500_error()

    # Without this return the decorator evaluates to None and replaces the
    # decorated view with None.
    return request_wrapper
@cw_view_config(route_name=ApiRoutes.schema, request_method="GET")
def schema_route(request: Request):
    """
    Returns this instance's Schema

    :param request: The pyramid request
    :return: The repository schema exported as a dict by ``JSONSchemaExporter``
    """
    # TODO block this if we are not connected and anon is disabled
    repo = get_cw_repo(request)
    exporter = JSONSchemaExporter()
    return exporter.export_as_dict(repo.schema)
@cw_view_config(route_name=ApiRoutes.rql)
def rql_route(request: Request):
    """
    Executes the given rql query

    The validated request body must provide the ``query`` string and its ``params``.

    :param request: The pyramid request
    :return: The rows of the query result
    :raise: An HTTP 400 error when the query is rejected (``RQLException``,
        ``QueryError``, ``ValidationError`` or ``UnknownType``)
    """
    request_params = request.openapi_validated.body
    query: str = request_params["query"]
    params: dict = request_params["params"]
    try:
        result = get_cw_request(request).execute(query, params)
    except (RQLException, QueryError, ValidationError, UnknownType) as e:
        log.info(e.__class__.__name__, exc_info=True)
        raise get_http_error(400, e.__class__.__name__, str(e))
    else:
        # NOTE(review): the body of this branch was missing, so the view returned
        # nothing. Assumes ``execute`` returns a CubicWeb ResultSet, whose
        # ``rows`` attribute holds the result rows -- confirm.
        return result.rows
@cw_view_config(route_name=ApiRoutes.login)
def login_route(request: Request):
    """
    Tries to log in the user

    The validated request body must provide the ``login`` and ``password`` strings.
    On success the user is remembered through the authentication policy of the request
    and the response status is set to 204.

    :param request: The pyramid request
    :raise: An HTTP 401 error when the authentication fails
    """
    request_params = request.openapi_validated.body
    login: str = request_params["login"]
    pwd: str = request_params["password"]

    repo = get_cw_repo(request)
    with repo.internal_cnx() as cnx:
        try:
            cwuser = repo.authenticate_user(cnx, login, password=pwd)
        except AuthenticationError:
            raise get_http_error(
                401, "AuthenticationFailure", "Login and/or password invalid."
            )
        else:
            # The user attributes are read while the internal connection is open.
            request.authentication_policy.remember(
                request,
                cwuser.eid,
                login=cwuser.login,
                firstname=cwuser.firstname,
                lastname=cwuser.surname,
            )
            request.response.status_code = 204
@cw_view_config(route_name=ApiRoutes.transaction_begin)
def transaction_begin_route(request: Request):
    """
    Starts a new transaction

    :param request: The pyramid request
    :return: Whatever ``begin_transaction`` returns for the user of the request
        (presumably the identifier of the new transaction -- confirm)
    """
    transactions = get_cw_repo(request).api_transactions
    user = get_cw_request(request).user
    return transactions.begin_transaction(user)
@cw_view_config(route_name=ApiRoutes.transaction_execute)
def transaction_execute_route(request: Request):
    """
    Executes the given rql query as part of a transaction

    The validated request body must provide the transaction ``uuid``, the ``query``
    string and its ``params``.

    :param request: The pyramid request
    :return: The rows of the query result
    :raise: An HTTP 400 error when the query is rejected (``RQLException``,
        ``QueryError``, ``ValidationError`` or ``UnknownType``)
    """
    transactions = get_cw_repo(request).api_transactions
    request_params = request.openapi_validated.body
    uuid: str = request_params["uuid"]
    query: str = request_params["query"]
    params: dict = request_params["params"]
    # Only the query execution is guarded, as in the plain rql route.
    try:
        result = transactions[uuid].execute(query, params)
    except (RQLException, QueryError, ValidationError, UnknownType) as e:
        log.info(e.__class__.__name__, exc_info=True)
        raise get_http_error(400, e.__class__.__name__, str(e))
    else:
        # NOTE(review): the body of this branch was missing, so the view returned
        # nothing. Assumes ``execute`` returns a CubicWeb ResultSet, whose
        # ``rows`` attribute holds the result rows -- confirm.
        return result.rows
@cw_view_config(route_name=ApiRoutes.transaction_commit)
def transaction_commit_route(request: Request):
    """
    Commits a transaction

    The validated request body must provide the transaction ``uuid``.

    :param request: The pyramid request
    :return: The result of the commit
    :raise: An HTTP 400 error when the commit is rejected (``RQLException``,
        ``QueryError``, ``ValidationError`` or ``UnknownType``)
    """
    transactions = get_cw_repo(request).api_transactions
    request_params = request.openapi_validated.body
    uuid: str = request_params["uuid"]
    try:
        commit_result = transactions[uuid].commit()
    except (RQLException, QueryError, ValidationError, UnknownType) as e:
        log.info(e.__class__.__name__, exc_info=True)
        raise get_http_error(400, e.__class__.__name__, str(e))
    else:
        # NOTE(review): a rollback is issued after a successful commit and the
        # transaction is not ended; presumably this resets it for further
        # use -- confirm.
        transactions[uuid].rollback()
        return commit_result
@cw_view_config(route_name=ApiRoutes.transaction_rollback)
def transaction_rollback_route(request: Request):
    """
    Rollback a transaction

    The validated request body must provide the transaction ``uuid``. The transaction
    is rolled back and then ended.

    :param request: The pyramid request
    :return: The result of the rollback
    """
    transactions = get_cw_repo(request).api_transactions
    request_params = request.openapi_validated.body
    uuid: str = request_params["uuid"]
    rollback_result = transactions[uuid].rollback()
    transactions.end_transaction(uuid)
    return rollback_result
def includeme(config: Configurator):
    """
    Pyramid entry point used to set up the API on the given configurator.

    Attaches an ``ApiTransactionsRepository`` to the CubicWeb repository, registers the
    openapi routes, then scans for the decorated views.

    :param config: The pyramid configurator
    """
    repository = get_cw_repo(config)
    # The transaction routes read this attribute back from the repository.
    repository.api_transactions = ApiTransactionsRepository(repository)

    config.pyramid_openapi3_register_routes()
    config.scan()