Commit 0acb445f authored by Philippe Pepiot's avatar Philippe Pepiot
Browse files

Fix testing tasks creating other tasks

When a task create a new task (by calling start_async_task), _TEST_TASK was
reset during the loop on it, this was leading to a KeyError (in case of
multiple tasks) or in sub-tasks not being started.

Fix this by not overriding _TEST_TASK for each new cubicweb connection and by
consuming _TEST_TASK until there is no tasks left.
parent 0aca21e47a99
......@@ -50,7 +50,10 @@ def run_all_tasks(cnx=None):
'workflow synchronisation', DeprecationWarning,
results = {}
for task_eid in list(_TEST_TASKS):
# run all tasks and gather results.
# Tasks can create other tasks, so run them until there is no one left.
while _TEST_TASKS:
task_eid = list(_TEST_TASKS)[0]
results[task_eid] = _TEST_TASKS.pop(task_eid).delay()
if cnx is not None and celery.current_app.conf.CELERY_ALWAYS_EAGER:
......@@ -110,7 +113,7 @@ class StartCeleryTaskOp(DataOperationMixIn, Operation):
global _TEST_TASKS
if self.cnx.vreg.config.mode == 'test':
# In test mode, task should run explicitly with run_all_tasks()
_TEST_TASKS = self.cnx.transaction_data.get('celerytask', {})
_TEST_TASKS.update(self.cnx.transaction_data.get('celerytask', {}))
for eid in self.get_data():
task = self.cnx.transaction_data.get('celerytask', {}).get(eid)
......@@ -22,6 +22,7 @@ import logging
import unittest
import six
import celery
import celery.result
import mock
......@@ -247,6 +248,38 @@ class CeleryTaskTC(BaseCeleryTaskTC):
self.assertEqual({'terminate': True, 'signal': 'SIGKILL'}, kwargs)
class StartAsyncTaskTC(testlib.CubicWebTC):
def setUp(self):
super(StartAsyncTaskTC, self).setUp()
celery.current_app.conf.CELERY_ALWAYS_EAGER = True
def test_task_creating_task(self):
with self.admin_access.cnx() as cnx:
def task_a():
with self.admin_access.cnx() as admin_cnx:
start_async_task(admin_cnx, 'task_b')
return 'a'
def task_b():
return 'b'
def task_c():
return 'c'
start_async_task(cnx, 'task_a')
start_async_task(cnx, 'task_c')
results = run_all_tasks(cnx)
[r.get() for r in results.values()], ['a', 'b', 'c'])
if __name__ == '__main__':
from unittest import main
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment