diff options
Diffstat (limited to 'azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/api/task_graph.py')
-rw-r--r-- | azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/api/task_graph.py | 295 |
1 files changed, 295 insertions, 0 deletions
diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/api/task_graph.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/api/task_graph.py new file mode 100644 index 0000000..900a0d1 --- /dev/null +++ b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/api/task_graph.py @@ -0,0 +1,295 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Task graph. +""" + +from collections import Iterable + +from networkx import DiGraph, topological_sort + +from ....utils.uuid import generate_uuid +from . import task as api_task + + +class TaskNotInGraphError(Exception): + """ + An error representing a scenario where a given task is not in the graph as expected. + """ + pass + + +def _filter_out_empty_tasks(func=None): + if func is None: + return lambda f: _filter_out_empty_tasks(func=f) + + def _wrapper(task, *tasks, **kwargs): + return func(*(t for t in (task,) + tuple(tasks) if t), **kwargs) + return _wrapper + + +class TaskGraph(object): + """ + Task graph builder. + """ + + def __init__(self, name): + self.name = name + self._id = generate_uuid(variant='uuid') + self._graph = DiGraph() + + def __repr__(self): + return '{name}(id={self._id}, name={self.name}, graph={self._graph!r})'.format( + name=self.__class__.__name__, self=self) + + @property + def id(self): + """ + ID of the graph + """ + return self._id + + # graph traversal methods + + @property + def tasks(self): + """ + Iterator over tasks in the graph. + """ + for _, data in self._graph.nodes_iter(data=True): + yield data['task'] + + def topological_order(self, reverse=False): + """ + Topological sort of the graph. + + :param reverse: whether to reverse the sort + :return: list which represents the topological sort + """ + for task_id in topological_sort(self._graph, reverse=reverse): + yield self.get_task(task_id) + + def get_dependencies(self, dependent_task): + """ + Iterates over the task's dependencies. + + :param dependent_task: task whose dependencies are requested + :raises ~aria.orchestrator.workflows.api.task_graph.TaskNotInGraphError: if + ``dependent_task`` is not in the graph + """ + if not self.has_tasks(dependent_task): + raise TaskNotInGraphError('Task id: {0}'.format(dependent_task.id)) + for _, dependency_id in self._graph.out_edges_iter(dependent_task.id): + yield self.get_task(dependency_id) + + def get_dependents(self, dependency_task): + """ + Iterates over the task's dependents. + + :param dependency_task: task whose dependents are requested + :raises ~aria.orchestrator.workflows.api.task_graph.TaskNotInGraphError: if + ``dependency_task`` is not in the graph + """ + if not self.has_tasks(dependency_task): + raise TaskNotInGraphError('Task id: {0}'.format(dependency_task.id)) + for dependent_id, _ in self._graph.in_edges_iter(dependency_task.id): + yield self.get_task(dependent_id) + + # task methods + + def get_task(self, task_id): + """ + Get a task instance that's been inserted to the graph by the task's ID. + + :param basestring task_id: task ID + :raises ~aria.orchestrator.workflows.api.task_graph.TaskNotInGraphError: if no task found in + the graph with the given ID + """ + if not self._graph.has_node(task_id): + raise TaskNotInGraphError('Task id: {0}'.format(task_id)) + data = self._graph.node[task_id] + return data['task'] + + @_filter_out_empty_tasks + def add_tasks(self, *tasks): + """ + Adds a task to the graph. + + :param task: task + :return: list of added tasks + :rtype: list + """ + assert all([isinstance(task, (api_task.BaseTask, Iterable)) for task in tasks]) + return_tasks = [] + + for task in tasks: + if isinstance(task, Iterable): + return_tasks += self.add_tasks(*task) + elif not self.has_tasks(task): + self._graph.add_node(task.id, task=task) + return_tasks.append(task) + + return return_tasks + + @_filter_out_empty_tasks + def remove_tasks(self, *tasks): + """ + Removes the provided task from the graph. + + :param task: task + :return: list of removed tasks + :rtype: list + """ + return_tasks = [] + + for task in tasks: + if isinstance(task, Iterable): + return_tasks += self.remove_tasks(*task) + elif self.has_tasks(task): + self._graph.remove_node(task.id) + return_tasks.append(task) + + return return_tasks + + @_filter_out_empty_tasks + def has_tasks(self, *tasks): + """ + Checks whether a task is in the graph. + + :param task: task + :return: ``True`` if all tasks are in the graph, otherwise ``False`` + :rtype: list + """ + assert all(isinstance(t, (api_task.BaseTask, Iterable)) for t in tasks) + return_value = True + + for task in tasks: + if isinstance(task, Iterable): + return_value &= self.has_tasks(*task) + else: + return_value &= self._graph.has_node(task.id) + + return return_value + + def add_dependency(self, dependent, dependency): + """ + Adds a dependency for one item (task, sequence or parallel) on another. + + The dependent will only be executed after the dependency terminates. If either of the items + is either a sequence or a parallel, multiple dependencies may be added. + + :param dependent: dependent (task, sequence or parallel) + :param dependency: dependency (task, sequence or parallel) + :return: ``True`` if the dependency between the two hadn't already existed, otherwise + ``False`` + :rtype: bool + :raises ~aria.orchestrator.workflows.api.task_graph.TaskNotInGraphError: if either the + dependent or dependency are tasks which are not in the graph + """ + if not (self.has_tasks(dependent) and self.has_tasks(dependency)): + raise TaskNotInGraphError() + + if self.has_dependency(dependent, dependency): + return + + if isinstance(dependent, Iterable): + for dependent_task in dependent: + self.add_dependency(dependent_task, dependency) + else: + if isinstance(dependency, Iterable): + for dependency_task in dependency: + self.add_dependency(dependent, dependency_task) + else: + self._graph.add_edge(dependent.id, dependency.id) + + def has_dependency(self, dependent, dependency): + """ + Checks whether one item (task, sequence or parallel) depends on another. + + Note that if either of the items is either a sequence or a parallel, and some of the + dependencies exist in the graph but not all of them, this method will return ``False``. + + :param dependent: dependent (task, sequence or parallel) + :param dependency: dependency (task, sequence or parallel) + :return: ``True`` if the dependency between the two exists, otherwise ``False`` + :rtype: bool + :raises ~aria.orchestrator.workflows.api.task_graph.TaskNotInGraphError: if either the + dependent or dependency are tasks which are not in the graph + """ + if not (dependent and dependency): + return False + elif not (self.has_tasks(dependent) and self.has_tasks(dependency)): + raise TaskNotInGraphError() + + return_value = True + + if isinstance(dependent, Iterable): + for dependent_task in dependent: + return_value &= self.has_dependency(dependent_task, dependency) + else: + if isinstance(dependency, Iterable): + for dependency_task in dependency: + return_value &= self.has_dependency(dependent, dependency_task) + else: + return_value &= self._graph.has_edge(dependent.id, dependency.id) + + return return_value + + def remove_dependency(self, dependent, dependency): + """ + Removes a dependency for one item (task, sequence or parallel) on another. + + Note that if either of the items is either a sequence or a parallel, and some of the + dependencies exist in the graph but not all of them, this method will not remove any of the + dependencies and return ``False``. + + :param dependent: dependent (task, sequence or parallel) + :param dependency: dependency (task, sequence or parallel) + :return: ``False`` if the dependency between the two hadn't existed, otherwise ``True`` + :rtype: bool + :raises ~aria.orchestrator.workflows.api.task_graph.TaskNotInGraphError: if either the + dependent or dependency are tasks which are not in the graph + """ + if not (self.has_tasks(dependent) and self.has_tasks(dependency)): + raise TaskNotInGraphError() + + if not self.has_dependency(dependent, dependency): + return + + if isinstance(dependent, Iterable): + for dependent_task in dependent: + self.remove_dependency(dependent_task, dependency) + elif isinstance(dependency, Iterable): + for dependency_task in dependency: + self.remove_dependency(dependent, dependency_task) + else: + self._graph.remove_edge(dependent.id, dependency.id) + + @_filter_out_empty_tasks + def sequence(self, *tasks): + """ + Creates and inserts a sequence into the graph, effectively each task i depends on i-1. + + :param tasks: iterable of dependencies + :return: provided tasks + """ + if tasks: + self.add_tasks(*tasks) + + for i in xrange(1, len(tasks)): + self.add_dependency(tasks[i], tasks[i-1]) + + return tasks |