Newer
Older
# -*- coding: utf-8 -*-
# copyright 2022 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact https://www.logilab.fr -- mailto:contact@logilab.fr
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import logging
from enum import Enum
from functools import wraps
from typing import Union

from cubicweb import AuthenticationError, QueryError
from cubicweb.pyramid.core import CubicWebPyramidRequest
from cubicweb.schema_exporters import JSONSchemaExporter
from cubicweb.server.repository import Repository
from marshmallow_dataclass import class_schema
from pyramid.config import Configurator
from pyramid.httpexceptions import HTTPError
from pyramid.request import Request
from pyramid.view import view_config
from rql import RQLException
from yams import ValidationError, UnknownType

from cubicweb_api.api_transaction import ApiTransactionsRepository
from cubicweb_api.constants import (
    DEFAULT_ROUTE_PARAMS,
    API_PATTERN_PREFIX,
    API_ROUTE_PREFIX,
)
from cubicweb_api.httperrors import get_http_error
from cubicweb_api.jwt_auth import setup_jwt
from cubicweb_api.openapi.openapi import register_openapi_routes
from cubicweb_api.request_params import (
    RqlParams,
    get_request_params,
    LoginParams,
    TransactionExecuteParams,
    TransactionParams,
)

# Logger shared by every view of this module, named after the module itself.
log = logging.getLogger(name=__name__)
class ApiRoutes(str, Enum):
    """Relative paths of every route exposed by the API.

    Each member's value is used both as the URL pattern (``/<value>``) and as
    the suffix of the Pyramid route name; see ``add_cw_routes`` and
    ``cw_view_config``, which interpolate members with f-strings.
    """

    schema = "schema"
    rql = "rql"
    login = "login"
    transaction_begin = "transaction/begin"
    transaction_execute = "transaction/execute"
    transaction_commit = "transaction/commit"
    transaction_rollback = "transaction/rollback"
    help = "help"

    def __str__(self) -> str:
        # Since Python 3.11, format() (and therefore f-strings) of a
        # ``(str, Enum)`` member follows Enum.__str__ and yields
        # "ApiRoutes.schema" instead of "schema", which would break the route
        # names and URL patterns built from these members. Returning the raw
        # value keeps str() and format() equal to the path on every version.
        return self.value
def get_cw_request(request: Request) -> CubicWebPyramidRequest:
    """Return the CubicWeb request attached to a Pyramid request.

    :param request: the incoming Pyramid request.
    :return: the object stored in ``request.cw_request``.
    """
    cw_request = request.cw_request
    return cw_request
def get_cw_repo(req_or_conf: Union[Request, Configurator]) -> Repository:
    """Fetch the CubicWeb repository from the Pyramid registry.

    Both requests and configurators expose a ``registry``, so either may be
    passed.

    :param req_or_conf: a Pyramid request or configurator.
    :return: the registry entry stored under ``"cubicweb.repository"``.
    """
    registry = req_or_conf.registry
    return registry["cubicweb.repository"]
def cw_view_config(route_name: str, **kwargs):
    """Build a Pyramid ``view_config`` decorator for an API route.

    The route name is prefixed with ``API_PATTERN_PREFIX``. The keyword
    arguments are layered on top of ``DEFAULT_ROUTE_PARAMS``, and explicit
    ``kwargs`` win on conflicting keys.

    :param route_name: the unprefixed route name.
    :param kwargs: extra arguments forwarded to ``view_config``.
    :return: the configured ``view_config`` decorator.
    """
    full_route_name = f"{API_PATTERN_PREFIX}{route_name}"
    settings = dict(DEFAULT_ROUTE_PARAMS)
    settings.update(kwargs)
    return view_config(route_name=full_route_name, **settings)
def view_exception_handler(func):
    """
    Decorate a pyramid view so that exceptions escaping it become HTTP errors.

    - ``HTTPError`` raised by the view is returned as the response unchanged.
    - ``AuthenticationError`` is turned into an HTTP 401 error response.
    - Any other exception is logged with its traceback and re-raised as a
      generic HTTP 500 error whose body does not expose the original message.

    :param func: The pyramid view function, taking the request as its only
        argument.
    :return: The wrapped view function.
    """

    # Make sure we do not lose the docstrings when wrapping the function
    @wraps(func)
    def request_wrapper(request: Request):
        try:
            return func(request)
        except HTTPError as e:
            return e
        except AuthenticationError as e:
            log.info(e.__class__.__name__, exc_info=True)
            return get_http_error(401, e.__class__.__name__, str(e))
        except Exception:
            # An unexpected failure is a server bug: log it at ERROR level
            # with its traceback so it is not filtered out by a typical
            # production logging configuration.
            log.exception("ServerError")
            # Do not return error content as it could lead to security leaks
            raise get_http_error(
                500,
                "ServerError",
                "The server encountered an error. Please contact support.",
            )

    return request_wrapper
@cw_view_config(route_name=ApiRoutes.schema, request_method="GET")
def schema_route(request: Request):
    """
    ---
    get:
      description: Returns this instance's Schema
      responses:
        200:
          content:
            application/json:
              schema:
                type: object
    """
    # The OpenAPI docstring above previously declared RqlParams (the /rql
    # request body) as this response's schema; the route actually returns
    # the dictionary produced by JSONSchemaExporter.export_as_dict.
    # TODO block this if we are not connected and anon is disabled
    repo = get_cw_repo(request)
    exporter = JSONSchemaExporter()
    exported_schema = exporter.export_as_dict(repo.schema)
    return exported_schema
@cw_view_config(route_name=ApiRoutes.rql)
def rql_route(request: Request):
    """
    ---
    get:
      description: Executes the given rql query
      requestBody:
        content:
          application/json:
            schema: RqlParams
      responses:
        200:
          content:
            application/json:
              schema:
                type: array
                items:
                  type: object
        400:
          content:
            application/json:
              schema: ErrorSchema
    """
    # Deserialize/validate the request payload into an RqlParams dataclass.
    schema = class_schema(RqlParams)()
    request_params: RqlParams = get_request_params(request, schema)
    query = request_params.query
    params = request_params.params
    try:
        result = get_cw_request(request).execute(query, params)
    except (RQLException, QueryError, ValidationError, UnknownType) as e:
        # Query / validation errors are reported to the client as HTTP 400.
        log.info(e.__class__.__name__, exc_info=True)
        raise get_http_error(400, e.__class__.__name__, str(e))
    else:
        # NOTE(review): returning ``result.rows`` assumes execute() returns a
        # cubicweb ResultSet and that the route renderer serializes plain
        # lists (e.g. JSON) — confirm against DEFAULT_ROUTE_PARAMS.
        return result.rows
@cw_view_config(route_name=ApiRoutes.login)
def login_route(request: Request):
    """
    ---
    get:
      description: Tries to log in the user
      requestBody:
        content:
          application/json:
            schema: LoginParams
      responses:
        204:
          description: Token has been created and returned in set-cookie header
          headers:
            Set-Cookie:
              description: The created JWT
              schema:
                type: String
        401:
          content:
            application/json:
              schema: ErrorSchema
    """
    credentials: LoginParams = get_request_params(
        request, class_schema(LoginParams)()
    )
    user_login, user_password = credentials.login, credentials.password
    repo = get_cw_repo(request)
    with repo.internal_cnx() as cnx:
        try:
            user = repo.authenticate_user(cnx, user_login, password=user_password)
        except AuthenticationError:
            # Every authentication failure gets the same 401 message.
            raise get_http_error(
                401, "AuthenticationFailure", "Login and/or password invalid."
            )
        # Authentication succeeded: let the authentication policy remember the
        # user (the docstring documents a JWT returned via Set-Cookie).
        request.authentication_policy.remember(
            request,
            user.eid,
            login=user.login,
            firstname=user.firstname,
            lastname=user.surname,
        )
        request.response.status_code = 204
@cw_view_config(route_name=ApiRoutes.transaction_begin)
def transaction_begin_route(request: Request):
    """
    ---
    get:
      description: Starts a new transaction
      responses:
        200:
          content:
            application/json:
              schema:
                type: String
    """
    # Open a transaction owned by the requesting user and hand back whatever
    # begin_transaction returns (documented above as a string).
    txn_store = get_cw_repo(request).api_transactions
    owner = get_cw_request(request).user
    new_transaction = txn_store.begin_transaction(owner)
    return new_transaction
@cw_view_config(route_name=ApiRoutes.transaction_execute)
def transaction_execute_route(request: Request):
    """
    ---
    get:
      description: Executes the given rql query as part of a transaction
      requestBody:
        content:
          application/json:
            schema: TransactionExecuteParams
      responses:
        200:
          content:
            application/json:
              schema:
                type: array
                items:
                  type: object
        400:
          content:
            application/json:
              schema: ErrorSchema
    """
    transactions = get_cw_repo(request).api_transactions
    # Deserialize/validate the payload (transaction uuid, query, params).
    schema = class_schema(TransactionExecuteParams)()
    params: TransactionExecuteParams = get_request_params(request, schema)
    try:
        result = transactions[params.uuid].execute(params.query, params.params)
    except (RQLException, QueryError, ValidationError, UnknownType) as e:
        # Query / validation errors are reported to the client as HTTP 400.
        log.info(e.__class__.__name__, exc_info=True)
        raise get_http_error(400, e.__class__.__name__, str(e))
    else:
        # NOTE(review): returning ``result.rows`` assumes the transaction's
        # execute() returns a cubicweb ResultSet, mirroring rql_route —
        # confirm against ApiTransactionsRepository.
        return result.rows
@cw_view_config(route_name=ApiRoutes.transaction_commit)
def transaction_commit_route(request: Request):
    """
    ---
    get:
      description: Commits a transaction
      requestBody:
        content:
          application/json:
            schema: TransactionParams
      responses:
        200:
          description: The transaction was successful
    """
    txn_store = get_cw_repo(request).api_transactions
    commit_params: TransactionParams = get_request_params(
        request, class_schema(TransactionParams)()
    )
    txn_id = commit_params.uuid
    try:
        outcome = txn_store[txn_id].commit()
    except (RQLException, QueryError, ValidationError, UnknownType) as err:
        # Errors raised while committing are reported as HTTP 400.
        log.info(err.__class__.__name__, exc_info=True)
        raise get_http_error(400, err.__class__.__name__, str(err))
    # NOTE(review): a rollback right after a successful commit, and the
    # transaction is not ended here (unlike the rollback route) — confirm
    # this is intended.
    txn_store[txn_id].rollback()
    return outcome
@cw_view_config(route_name=ApiRoutes.transaction_rollback)
def transaction_rollback_route(request: Request):
    """
    ---
    get:
      description: Rollback a transaction
      requestBody:
        content:
          application/json:
            schema: TransactionParams
      responses:
        200:
          description: The transaction rollback was successful
    """
    txn_store = get_cw_repo(request).api_transactions
    rollback_params: TransactionParams = get_request_params(
        request, class_schema(TransactionParams)()
    )
    txn_id = rollback_params.uuid
    outcome = txn_store[txn_id].rollback()
    # After rolling back, end the transaction (presumably releasing it from
    # the repository — see ApiTransactionsRepository.end_transaction).
    txn_store.end_transaction(txn_id)
    return outcome
def add_cw_routes(config: Configurator):
    """Register one Pyramid route per ``ApiRoutes`` member.

    The route name is ``API_PATTERN_PREFIX`` followed by the member, and the
    URL pattern is ``/`` followed by the member.

    :param config: the Pyramid configurator to register routes on.
    """
    for route in ApiRoutes:
        name = "{}{}".format(API_PATTERN_PREFIX, route)
        pattern = "/{}".format(route)
        config.add_route(name, pattern)
def includeme(config: Configurator):
    """Pyramid include hook for the API.

    Attaches an ``ApiTransactionsRepository`` to the CubicWeb repository as
    ``api_transactions`` (read by the transaction routes), then registers the
    API routes under ``API_ROUTE_PREFIX``.

    :param config: the Pyramid configurator.
    """
    cw_repo = get_cw_repo(config)
    transactions_store = ApiTransactionsRepository(cw_repo)
    cw_repo.api_transactions = transactions_store
    # NOTE(review): @view_config decorators only take effect once this module
    # is scanned (config.scan()); confirm that happens elsewhere.
    config.include(add_cw_routes, route_prefix=API_ROUTE_PREFIX)