# Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Celery task executor. """ import threading import Queue from aria.orchestrator.workflows.executor import BaseExecutor class CeleryExecutor(BaseExecutor): """ Celery task executor. """ def __init__(self, app, *args, **kwargs): super(CeleryExecutor, self).__init__(*args, **kwargs) self._app = app self._started_signaled = False self._started_queue = Queue.Queue(maxsize=1) self._tasks = {} self._results = {} self._receiver = None self._stopped = False self._receiver_thread = threading.Thread(target=self._events_receiver) self._receiver_thread.daemon = True self._receiver_thread.start() self._started_queue.get(timeout=30) def _execute(self, ctx): self._tasks[ctx.id] = ctx arguments = dict(arg.unwrapped for arg in ctx.task.arguments.itervalues()) arguments['ctx'] = ctx.context self._results[ctx.id] = self._app.send_task( ctx.operation_mapping, kwargs=arguments, task_id=ctx.task.id, queue=self._get_queue(ctx)) def close(self): self._stopped = True if self._receiver: self._receiver.should_stop = True self._receiver_thread.join() @staticmethod def _get_queue(task): return None if task else None # TODO def _events_receiver(self): with self._app.connection() as connection: self._receiver = self._app.events.Receiver(connection, handlers={ 'task-started': self._celery_task_started, 'task-succeeded': self._celery_task_succeeded, 'task-failed': self._celery_task_failed, }) for _ in self._receiver.itercapture(limit=None, timeout=None, wakeup=True): if not self._started_signaled: self._started_queue.put(True) self._started_signaled = True if self._stopped: return def _celery_task_started(self, event): self._task_started(self._tasks[event['uuid']]) def _celery_task_succeeded(self, event): task, _ = self._remove_task(event['uuid']) self._task_succeeded(task) def _celery_task_failed(self, event): task, async_result = self._remove_task(event['uuid']) try: exception = async_result.result except BaseException as e: exception = RuntimeError( 'Could not de-serialize exception of task {0} --> {1}: {2}' .format(task.name, type(e).__name__, str(e))) self._task_failed(task, exception=exception) def _remove_task(self, task_id): return self._tasks.pop(task_id), self._results.pop(task_id)