# -*- coding: utf-8 -*-
# copyright 2016 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr -- mailto:contact@logilab.fr
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""cubicweb-celerytask entity's classes"""

import six

import celery
from celery.result import AsyncResult, result_from_tuple

from cubicweb import NoResultError
from cubicweb.entities import AnyEntity, fetch_config
from cubicweb.view import EntityAdapter
from cubicweb.predicates import is_instance
from cubicweb.server.hook import DataOperationMixIn, Operation

from cw_celerytask_helpers.filelogger import get_task_logs

from cubicweb_celerytask import STATES, FINAL_STATES

# Marker used to wrap translatable message ids (see ``CeleryTaskAdapter.state``).
_ = six.text_type

# Tasks registered in test mode, keyed by CeleryTask eid; they are only
# executed when ``run_all_tasks()`` is called.
_TEST_TASKS = {}

# Placeholder name given to CeleryTask entities whose real name is not known.
# NOTE(review): this is an empty string here -- confirm this is the intended
# placeholder value.
UNKNOWN_TASK_NAME = six.text_type('')


def get_tasks():
    """Return tasks to be run (for use in cubicweb test mode)

    A shallow copy is returned so callers cannot alter the pending registry.
    """
    return _TEST_TASKS.copy()


def run_all_tasks(cnx):
    """Run all pending tasks (for use in cubicweb test mode)

    Return a dict mapping each CeleryTask eid to the result of ``delay()``.
    When celery is configured with ``task_always_eager``, the matching
    workflow transition is also fired on every executed task.
    """
    results = {}
    # run all tasks and gather results.
    # Tasks can create other tasks, so run them until there is no one left.
    while _TEST_TASKS:
        task_eid = next(iter(_TEST_TASKS))
        task = _TEST_TASKS.pop(task_eid)
        # Ensure current task id is in the scope of the current test
        if task.id is not None and not cnx.execute(
            'Any X WHERE X is CeleryTask, X task_id %(task_id)s',
            {'task_id': task.freeze().id}
        ):
            continue
        results[task_eid] = task.delay()
    if celery.current_app.conf.task_always_eager:
        transitions = {
            STATES.SUCCESS: 'finish',
            STATES.FAILURE: 'fail',
        }
        for task_eid, result in results.items():
            wf = cnx.entity_from_eid(task_eid).cw_adapt_to('IWorkflowable')
            transition = transitions[result.state]
            comment = result.traceback
            # the workflow comment must be text, the traceback may be bytes
            if comment is not None and not isinstance(comment, six.text_type):
                comment = comment.decode('utf-8')
            wf.fire_transition(transition, comment)
    return results


def sync_task_state(cnx, task_id, task_name):
    """Synchronize the CeleryTask entity matching `task_id` with celery.

    Nothing is done while celery reports the task as ``PENDING``. Otherwise
    the entity is looked up by `task_id` (and created when missing) and its
    ``ICeleryTask`` adapter is asked to sync its state.
    """
    # the adapter class is used as a logger (``info`` is called on the class)
    log = CeleryTaskAdapter
    task_id = six.text_type(task_id)
    result = AsyncResult(task_id)
    if result.state == 'PENDING':
        log.info('Task %s state is unknown', task_id)
        return
    try:
        task = cnx.find('CeleryTask', task_id=task_id).one()
    except NoResultError:
        task = cnx.create_entity('CeleryTask', task_id=task_id,
                                 task_name=task_name or UNKNOWN_TASK_NAME)
        log.info('Created <CeleryTask %s (task_id %s)>', task.eid, task_id)
    task.cw_adapt_to('ICeleryTask').sync_state(task_id, task_name)


def start_async_task(cnx, task, *args, **kwargs):
    """Create and start a new task

    `task` can be either a task name, a task object or a task signature

    Return the created CeleryTask entity.
    """
    task_name = six.text_type(celery.signature(task).task)
    entity = cnx.create_entity('CeleryTask', task_name=task_name)
    entity.cw_adapt_to('ICeleryTask').start(task, *args, **kwargs)
    return entity


def task_in_backend(task_id):
    """Return True if the result backend holds an entry for `task_id`.

    Always return False when celery runs with ``task_always_eager``.
    """
    app = celery.current_app
    if app.conf.task_always_eager:
        return False
    backend = app.backend
    return backend.get(backend.get_key_for_task(task_id)) is not None


class StartCeleryTaskOp(DataOperationMixIn, Operation):
    """Operation handling, after commit, the tasks stored in
    ``transaction_data['celerytask']``.
    """

    def postcommit_event(self):
        """Start the registered tasks, or queue them in test mode."""
        pending = self.cnx.transaction_data.get('celerytask', {})
        if self.cnx.vreg.config.mode == 'test':
            # In test mode, task should run explicitly with run_all_tasks()
            _TEST_TASKS.update(pending)
        else:
            for eid in self.get_data():
                task = pending.get(eid)
                if task is not None:
                    task.delay()


class CeleryTask(AnyEntity):
    """Entity class for the CeleryTask entity type."""
    __regid__ = 'CeleryTask'
    fetch_attrs, cw_fetch_order = fetch_config(('task_name',))

    def dc_title(self):
        """Return the task name."""
        return self.task_name

    def dc_long_title(self):
        """Return the task name, suffixed by its state once finished."""
        adapted = self.cw_adapt_to('ICeleryTask')
        state, finished = adapted.state, adapted.finished
        title = self.task_name or self._cw._('subtask')
        if finished:
            title = '%s (%s)' % (title, self._cw._(state))
        return title

    @property
    def progress(self):
        """Yield this task's progress, then the ``progress`` of each subtask.

        NOTE(review): ``subtask.progress`` is itself a generator, so nested
        generators are yielded rather than flattened values -- confirm this
        is what callers expect.
        """
        yield self.cw_adapt_to('ICeleryTask').progress
        for subtask in self.reverse_parent_task:
            yield subtask.progress

    @property
    def parent_tasks(self):
        """Yield this task, then recursively the tasks reached through
        ``parent_task``.
        """
        yield self
        for task in self.parent_task:
            for ptask in task.parent_tasks:
                yield ptask

    def child_tasks(self):
        """Yield this task, then recursively the tasks reached through
        ``reverse_parent_task``.
        """
        yield self
        for task in self.reverse_parent_task:
            for ctask in task.child_tasks():
                yield ctask


class ICeleryTask(EntityAdapter):
    """Abstract adapter giving access to the celery side of a task entity.

    Concrete adapters must implement `task_id`, `task_name` and `sync_state`.
    """
    __regid__ = 'ICeleryTask'
    __abstract__ = True

    def start(self, name, *args, **kwargs):
        """Build the task and register it to be handled by
        ``StartCeleryTaskOp`` for the adapted entity.
        """
        eid = self.entity.eid
        task = self.get_task(name, *args, **kwargs)
        self._cw.transaction_data.setdefault('celerytask', {})[eid] = task
        StartCeleryTaskOp.get_instance(self._cw).add_data(eid)

    def get_task(self, name, *args, **kwargs):
        """Should return a celery task / signature or None

        This method is run in a precommit event
        """
        return celery.signature(name, args=args, kwargs=kwargs)

    def sync_state(self, task_id, task_name):
        """Triggered by celery-monitor"""
        raise NotImplementedError

    @property
    def task_id(self):
        """Celery id of the task (to be provided by concrete adapters)."""
        raise NotImplementedError

    @property
    def task_name(self):
        """Celery name of the task (to be provided by concrete adapters)."""
        raise NotImplementedError

    def revoke(self, terminate=True, signal='SIGKILL'):
        """Revoke the task, return the result of the celery control call."""
        # Go through the current app: the ``celery.task.control`` alias is a
        # deprecated shortcut to the very same control object.
        return celery.current_app.control.revoke(
            [self.task_id], terminate=terminate, signal=signal)

    @property
    def logs(self):
        """Logs of the task, or ``b''`` when there are none."""
        return get_task_logs(self.task_id) or b''

    @property
    def result(self):
        """``AsyncResult`` built from the task id."""
        return AsyncResult(self.task_id)

    @property
    def progress(self):
        """Progress of the task.

        Resolution order: ``1.`` in eager mode, the ``progress`` entry of the
        result info when there is one, the mean progress of the subtasks when
        there are some, else ``1.`` on success and ``0.`` otherwise.
        """
        if celery.current_app.conf.task_always_eager:
            return 1.
        result = self.result
        try:
            info = result.info
            if info and 'progress' in info:
                return info['progress']
        except TypeError:
            # info does not support membership test / key access
            pass
        children = self.entity.reverse_parent_task
        if children:
            total = sum(child.cw_adapt_to('ICeleryTask').progress
                        for child in children)
            # float() avoids a floor division on python 2 when every child
            # reports an integer progress
            return total / float(len(children))
        if result.state == STATES.SUCCESS:
            return 1.
        return 0.

    @property
    def state(self):
        """State of the celery result."""
        return self.result.state

    @property
    def finished(self):
        """True when the state is one of ``FINAL_STATES``."""
        return self.state in FINAL_STATES


class CeleryTaskAdapter(ICeleryTask):
    """Base adapter that store task call args in the transaction"""
    __select__ = ICeleryTask.__select__ & is_instance('CeleryTask')

    def attach_task(self, task, seen, parent=None):
        """Record `task` and, recursively, its subtasks as CeleryTask entities.

        The first task id met is stored on the adapted entity; each other id
        not yet in `seen` gets a new CeleryTask entity linked to `parent`.
        `seen` is the set of task ids already handled, it is updated in place.
        """
        task_id = six.text_type(task.freeze().id)
        if parent is None:
            parent = self.entity
        if self.entity.task_id is None:
            self.entity.cw_set(task_id=task_id)
        elif task_id not in seen:
            task_name = six.text_type(task.task)
            parent = self._cw.create_entity('CeleryTask',
                                            task_id=six.text_type(task_id),
                                            task_name=task_name,
                                            parent_task=parent)
        seen.add(task_id)
        if task.name in ('celery.chain', 'celery.group'):
            for subtask in task.tasks:
                self.attach_task(subtask, seen, parent)
        if task.name == 'celery.chord':
            self.attach_task(task.body, seen, parent)
            for subtask in task.tasks.tasks:
                self.attach_task(subtask, seen, parent)

    def get_task(self, name, *args, **kwargs):
        """Build the signature and attach it (and its subtasks) to entities."""
        task = super(CeleryTaskAdapter, self).get_task(
            name, *args, **kwargs)
        self.attach_task(task, set())
        return task

    @property
    def task_id(self):
        """``task_id`` attribute of the adapted entity."""
        return self.entity.task_id

    @property
    def task_name(self):
        """``task_name`` attribute of the adapted entity."""
        return self.entity.task_name

    def revoke(self, terminate=True, signal='SIGKILL'):
        """Revoke the task and all the tasks returned by ``child_tasks()``."""
        to_revoke = {e.task_id for e in self.entity.child_tasks()}
        # Go through the current app: the ``celery.task.control`` alias is a
        # deprecated shortcut to the very same control object.
        return celery.current_app.control.revoke(
            list(to_revoke), terminate=terminate, signal=signal)

    def attach_result(self, result):
        """Ensure every result reachable from `result` has a CeleryTask entity.

        Missing entities are created, and entities without ``parent_task``
        (other than the adapted one) are attached to the adapted entity.
        """
        def tree(result, seen=None):
            """Yield once each result reachable from `result` through its
            parent, its children and its "celerytask_subtasks" payload.
            """
            if seen is None:
                seen = set()
            if result.parent:
                for r in tree(result.parent, seen):
                    yield r
            for child in result.children or []:
                for r in tree(child, seen):
                    yield r
            if isinstance(result, AsyncResult):
                rresult = result.result
                if (isinstance(rresult, dict)
                        and "celerytask_subtasks" in rresult):
                    subtasks = result_from_tuple(
                        rresult["celerytask_subtasks"])
                    for r in tree(subtasks, seen):
                        yield r
            if result.task_id not in seen:
                seen.add(result.task_id)
                yield result

        for asr in tree(result):
            task_id = six.text_type(asr.id)
            try:
                cwtask = self._cw.find('CeleryTask', task_id=task_id).one()
            except NoResultError:
                cwtask = self._cw.create_entity(
                    'CeleryTask',
                    task_name=UNKNOWN_TASK_NAME,
                    task_id=six.text_type(task_id))
                self.info("Create <CeleryTask %s (task_id %s)>",
                          cwtask.eid, task_id)
            # compare eids rather than object identity so another instance of
            # the adapted entity is never made its own parent
            if not cwtask.parent_task and cwtask.eid != self.entity.eid:
                self.info('Set %s parent_task to %s (%s)',
                          cwtask.task_id, self.entity, self.entity.task_id)
                cwtask.cw_set(parent_task=self.entity)

    def sync_state(self, task_id, task_name, commit=True):
        """Update the entity (name, subtasks, workflow state) from celery.

        :param task_id: celery id of the task, only used in log messages
        :param task_name: celery name of the task, may be None
        :param commit: commit the connection once done when true
        """
        if (self.entity.task_name == UNKNOWN_TASK_NAME
                and task_name is not None):
            self.info('<CeleryTask %s (task_id %s)> Update name to %s',
                      self.entity.eid, task_id, task_name)
            self.entity.cw_set(task_name=six.text_type(task_name))
        result = self.result
        if result.ready():
            self.attach_result(result)
        transition = {
            STATES.SUCCESS: 'finish',
            STATES.FAILURE: 'fail',
            STATES.STARTED: 'start',
            STATES.REVOKED: 'fail',
            'PROGRESS': 'start',
        }.get(result.state)
        if transition is not None:
            self.info('<CeleryTask %s (task_id %s)> %s',
                      self.entity.eid, task_id, transition)
            wf = self.entity.cw_adapt_to('IWorkflowable')
            if (result.traceback is None
                    and result.state == STATES.FAILURE
                    and result.result is not None):
                # no traceback on a failure: use the result as comment
                comment = six.text_type(result.result)
            else:
                comment = result.traceback
            wf.fire_transition_if_possible(transition, comment)
        else:
            self.info('<CeleryTask %s (task_id %s)> no transition found for '
                      'state %s', self.entity.eid, task_id, result.state)
        if commit:
            self._cw.commit()

    @property
    def state(self):
        """State of the task.

        A final workflow state ('done' / 'failed') wins; otherwise the celery
        state is used when the backend knows the task, else 'unknown state'.
        """
        db_state = self.entity.cw_adapt_to('IWorkflowable').state
        db_final_state_map = {'done': STATES.SUCCESS,
                              'failed': STATES.FAILURE}
        if db_state in db_final_state_map:
            return db_final_state_map[db_state]
        elif task_in_backend(self.task_id):
            return super(CeleryTaskAdapter, self).state
        return _('unknown state')