summaryrefslogtreecommitdiffstats
path: root/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/executor/celery.py
diff options
context:
space:
mode:
Diffstat (limited to 'azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/executor/celery.py')
-rw-r--r--azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/executor/celery.py97
1 files changed, 97 insertions, 0 deletions
diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/executor/celery.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/executor/celery.py
new file mode 100644
index 0000000..a2b3513
--- /dev/null
+++ b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/executor/celery.py
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Celery task executor.
+"""
+
+import threading
+import Queue
+
+from aria.orchestrator.workflows.executor import BaseExecutor
+
+
+class CeleryExecutor(BaseExecutor):
+ """
+ Celery task executor.
+ """
+
+ def __init__(self, app, *args, **kwargs):
+ super(CeleryExecutor, self).__init__(*args, **kwargs)
+ self._app = app
+ self._started_signaled = False
+ self._started_queue = Queue.Queue(maxsize=1)
+ self._tasks = {}
+ self._results = {}
+ self._receiver = None
+ self._stopped = False
+ self._receiver_thread = threading.Thread(target=self._events_receiver)
+ self._receiver_thread.daemon = True
+ self._receiver_thread.start()
+ self._started_queue.get(timeout=30)
+
+ def _execute(self, ctx):
+ self._tasks[ctx.id] = ctx
+ arguments = dict(arg.unwrapped for arg in ctx.task.arguments.itervalues())
+ arguments['ctx'] = ctx.context
+ self._results[ctx.id] = self._app.send_task(
+ ctx.operation_mapping,
+ kwargs=arguments,
+ task_id=ctx.task.id,
+ queue=self._get_queue(ctx))
+
+ def close(self):
+ self._stopped = True
+ if self._receiver:
+ self._receiver.should_stop = True
+ self._receiver_thread.join()
+
+ @staticmethod
+ def _get_queue(task):
+ return None if task else None # TODO
+
+ def _events_receiver(self):
+ with self._app.connection() as connection:
+ self._receiver = self._app.events.Receiver(connection, handlers={
+ 'task-started': self._celery_task_started,
+ 'task-succeeded': self._celery_task_succeeded,
+ 'task-failed': self._celery_task_failed,
+ })
+ for _ in self._receiver.itercapture(limit=None, timeout=None, wakeup=True):
+ if not self._started_signaled:
+ self._started_queue.put(True)
+ self._started_signaled = True
+ if self._stopped:
+ return
+
+ def _celery_task_started(self, event):
+ self._task_started(self._tasks[event['uuid']])
+
+ def _celery_task_succeeded(self, event):
+ task, _ = self._remove_task(event['uuid'])
+ self._task_succeeded(task)
+
+ def _celery_task_failed(self, event):
+ task, async_result = self._remove_task(event['uuid'])
+ try:
+ exception = async_result.result
+ except BaseException as e:
+ exception = RuntimeError(
+ 'Could not de-serialize exception of task {0} --> {1}: {2}'
+ .format(task.name, type(e).__name__, str(e)))
+ self._task_failed(task, exception=exception)
+
+ def _remove_task(self, task_id):
+ return self._tasks.pop(task_id), self._results.pop(task_id)