diff options
Diffstat (limited to 'azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/executor/process.py')
-rw-r--r-- | azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/executor/process.py | 350 |
1 files changed, 350 insertions, 0 deletions
diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/executor/process.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/executor/process.py new file mode 100644 index 0000000..185f15f --- /dev/null +++ b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/executor/process.py @@ -0,0 +1,350 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Sub-process task executor. +""" + +# pylint: disable=wrong-import-position + +import os +import sys + +# As part of the process executor implementation, subprocess are started with this module as their +# entry point. We thus remove this module's directory from the python path if it happens to be +# there + +from collections import namedtuple + +script_dir = os.path.dirname(__file__) +if script_dir in sys.path: + sys.path.remove(script_dir) + +import contextlib +import io +import threading +import socket +import struct +import subprocess +import tempfile +import Queue +import pickle + +import psutil +import jsonpickle + +import aria +from aria.orchestrator.workflows.executor import base +from aria.extension import process_executor +from aria.utils import ( + imports, + exceptions, + process as process_utils +) + + +_INT_FMT = 'I' +_INT_SIZE = struct.calcsize(_INT_FMT) +UPDATE_TRACKED_CHANGES_FAILED_STR = \ + 'Some changes failed writing to storage. For more info refer to the log.' + + +_Task = namedtuple('_Task', 'proc, ctx') + + +class ProcessExecutor(base.BaseExecutor): + """ + Sub-process task executor. + """ + + def __init__(self, plugin_manager=None, python_path=None, *args, **kwargs): + super(ProcessExecutor, self).__init__(*args, **kwargs) + self._plugin_manager = plugin_manager + + # Optional list of additional directories that should be added to + # subprocesses python path + self._python_path = python_path or [] + + # Flag that denotes whether this executor has been stopped + self._stopped = False + + # Contains reference to all currently running tasks + self._tasks = {} + + self._request_handlers = { + 'started': self._handle_task_started_request, + 'succeeded': self._handle_task_succeeded_request, + 'failed': self._handle_task_failed_request, + } + + # Server socket used to accept task status messages from subprocesses + self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._server_socket.bind(('localhost', 0)) + self._server_socket.listen(10) + self._server_port = self._server_socket.getsockname()[1] + + # Used to send a "closed" message to the listener when this executor is closed + self._messenger = _Messenger(task_id=None, port=self._server_port) + + # Queue object used by the listener thread to notify this constructed it has started + # (see last line of this __init__ method) + self._listener_started = Queue.Queue() + + # Listener thread to handle subprocesses task status messages + self._listener_thread = threading.Thread(target=self._listener) + self._listener_thread.daemon = True + self._listener_thread.start() + + # Wait for listener thread to actually start before returning + self._listener_started.get(timeout=60) + + def close(self): + if self._stopped: + return + self._stopped = True + # Listener thread may be blocked on "accept" call. This will wake it up with an explicit + # "closed" message + self._messenger.closed() + self._server_socket.close() + self._listener_thread.join(timeout=60) + + # we use set(self._tasks) since tasks may change in the process of closing + for task_id in set(self._tasks): + self.terminate(task_id) + + def terminate(self, task_id): + task = self._remove_task(task_id) + # The process might have managed to finish, thus it would not be in the tasks list + if task: + try: + parent_process = psutil.Process(task.proc.pid) + for child_process in reversed(parent_process.children(recursive=True)): + try: + child_process.kill() + except BaseException: + pass + parent_process.kill() + except BaseException: + pass + + def _execute(self, ctx): + self._check_closed() + + # Temporary file used to pass arguments to the started subprocess + file_descriptor, arguments_json_path = tempfile.mkstemp(prefix='executor-', suffix='.json') + os.close(file_descriptor) + with open(arguments_json_path, 'wb') as f: + f.write(pickle.dumps(self._create_arguments_dict(ctx))) + + env = self._construct_subprocess_env(task=ctx.task) + # Asynchronously start the operation in a subprocess + proc = subprocess.Popen( + [ + sys.executable, + os.path.expanduser(os.path.expandvars(__file__)), + os.path.expanduser(os.path.expandvars(arguments_json_path)) + ], + env=env) + + self._tasks[ctx.task.id] = _Task(ctx=ctx, proc=proc) + + def _remove_task(self, task_id): + return self._tasks.pop(task_id, None) + + def _check_closed(self): + if self._stopped: + raise RuntimeError('Executor closed') + + def _create_arguments_dict(self, ctx): + return { + 'task_id': ctx.task.id, + 'function': ctx.task.function, + 'operation_arguments': dict(arg.unwrapped for arg in ctx.task.arguments.itervalues()), + 'port': self._server_port, + 'context': ctx.serialization_dict + } + + def _construct_subprocess_env(self, task): + env = os.environ.copy() + + if task.plugin_fk and self._plugin_manager: + # If this is a plugin operation, + # load the plugin on the subprocess env we're constructing + self._plugin_manager.load_plugin(task.plugin, env=env) + + # Add user supplied directories to injected PYTHONPATH + if self._python_path: + process_utils.append_to_pythonpath(*self._python_path, env=env) + + return env + + def _listener(self): + # Notify __init__ method this thread has actually started + self._listener_started.put(True) + while not self._stopped: + try: + with self._accept_request() as (request, response): + request_type = request['type'] + if request_type == 'closed': + break + request_handler = self._request_handlers.get(request_type) + if not request_handler: + raise RuntimeError('Invalid request type: {0}'.format(request_type)) + task_id = request['task_id'] + request_handler(task_id=task_id, request=request, response=response) + except BaseException as e: + self.logger.debug('Error in process executor listener: {0}'.format(e)) + + @contextlib.contextmanager + def _accept_request(self): + with contextlib.closing(self._server_socket.accept()[0]) as connection: + message = _recv_message(connection) + response = {} + try: + yield message, response + except BaseException as e: + response['exception'] = exceptions.wrap_if_needed(e) + raise + finally: + _send_message(connection, response) + + def _handle_task_started_request(self, task_id, **kwargs): + self._task_started(self._tasks[task_id].ctx) + + def _handle_task_succeeded_request(self, task_id, **kwargs): + task = self._remove_task(task_id) + if task: + self._task_succeeded(task.ctx) + + def _handle_task_failed_request(self, task_id, request, **kwargs): + task = self._remove_task(task_id) + if task: + self._task_failed( + task.ctx, exception=request['exception'], traceback=request['traceback']) + + +def _send_message(connection, message): + + # Packing the length of the entire msg using struct.pack. + # This enables later reading of the content. + def _pack(data): + return struct.pack(_INT_FMT, len(data)) + + data = jsonpickle.dumps(message) + msg_metadata = _pack(data) + connection.send(msg_metadata) + connection.sendall(data) + + +def _recv_message(connection): + # Retrieving the length of the msg to come. + def _unpack(conn): + return struct.unpack(_INT_FMT, _recv_bytes(conn, _INT_SIZE))[0] + + msg_metadata_len = _unpack(connection) + msg = _recv_bytes(connection, msg_metadata_len) + return jsonpickle.loads(msg) + + +def _recv_bytes(connection, count): + result = io.BytesIO() + while True: + if not count: + return result.getvalue() + read = connection.recv(count) + if not read: + return result.getvalue() + result.write(read) + count -= len(read) + + +class _Messenger(object): + + def __init__(self, task_id, port): + self.task_id = task_id + self.port = port + + def started(self): + """Task started message""" + self._send_message(type='started') + + def succeeded(self): + """Task succeeded message""" + self._send_message(type='succeeded') + + def failed(self, exception): + """Task failed message""" + self._send_message(type='failed', exception=exception) + + def closed(self): + """Executor closed message""" + self._send_message(type='closed') + + def _send_message(self, type, exception=None): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect(('localhost', self.port)) + try: + _send_message(sock, { + 'type': type, + 'task_id': self.task_id, + 'exception': exceptions.wrap_if_needed(exception), + 'traceback': exceptions.get_exception_as_string(*sys.exc_info()), + }) + response = _recv_message(sock) + response_exception = response.get('exception') + if response_exception: + raise response_exception + finally: + sock.close() + + +def _main(): + arguments_json_path = sys.argv[1] + with open(arguments_json_path) as f: + arguments = pickle.loads(f.read()) + + # arguments_json_path is a temporary file created by the parent process. + # so we remove it here + os.remove(arguments_json_path) + + task_id = arguments['task_id'] + port = arguments['port'] + messenger = _Messenger(task_id=task_id, port=port) + + function = arguments['function'] + operation_arguments = arguments['operation_arguments'] + context_dict = arguments['context'] + + try: + ctx = context_dict['context_cls'].instantiate_from_dict(**context_dict['context']) + except BaseException as e: + messenger.failed(e) + return + + try: + messenger.started() + task_func = imports.load_attribute(function) + aria.install_aria_extensions() + for decorate in process_executor.decorate(): + task_func = decorate(task_func) + task_func(ctx=ctx, **operation_arguments) + ctx.close() + messenger.succeeded() + except BaseException as e: + ctx.close() + messenger.failed(e) + +if __name__ == '__main__': + _main() |