diff options
Diffstat (limited to 'azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin')
14 files changed, 1472 insertions, 0 deletions
diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/__init__.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/__init__.py new file mode 100644 index 0000000..d15de99 --- /dev/null +++ b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/__init__.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Execution plugin package. +""" + +from contextlib import contextmanager +from . import instantiation + + +# Populated during execution of python scripts +ctx = None +inputs = None + + +@contextmanager +def python_script_scope(operation_ctx, operation_inputs): + global ctx + global inputs + try: + ctx = operation_ctx + inputs = operation_inputs + yield + finally: + ctx = None + inputs = None diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/common.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/common.py new file mode 100644 index 0000000..ce6746c --- /dev/null +++ b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/common.py @@ -0,0 +1,154 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Execution plugin utilities. +""" + +import json +import os +import tempfile + +import requests + +from . import constants +from . import exceptions + + +def is_windows(): + return os.name == 'nt' + + +def download_script(ctx, script_path): + split = script_path.split('://') + schema = split[0] + suffix = script_path.split('/')[-1] + file_descriptor, dest_script_path = tempfile.mkstemp(suffix='-{0}'.format(suffix)) + os.close(file_descriptor) + try: + if schema in ('http', 'https'): + response = requests.get(script_path) + if response.status_code == 404: + ctx.task.abort('Failed to download script: {0} (status code: {1})' + .format(script_path, response.status_code)) + content = response.text + with open(dest_script_path, 'wb') as f: + f.write(content) + else: + ctx.download_resource(destination=dest_script_path, path=script_path) + except: + os.remove(dest_script_path) + raise + return dest_script_path + + +def create_process_config(script_path, process, operation_kwargs, quote_json_env_vars=False): + """ + Updates a process with its environment variables, and return it. + + Gets a dict representing a process and a dict representing the environment variables. Converts + each environment variable to a format of:: + + <string representing the name of the variable>: + <json formatted string representing the value of the variable>. + + Finally, updates the process with the newly formatted environment variables, and return the + process. + + :param process: dict representing a process + :type process: dict + :param operation_kwargs: dict representing environment variables that should exist in the + process's running environment. + :type operation_kwargs: dict + :return: process updated with its environment variables + :rtype: dict + """ + process = process or {} + env_vars = operation_kwargs.copy() + if 'ctx' in env_vars: + del env_vars['ctx'] + env_vars.update(process.get('env', {})) + for k, v in env_vars.items(): + if isinstance(v, (dict, list, tuple, bool, int, float)): + v = json.dumps(v) + if quote_json_env_vars: + v = "'{0}'".format(v) + if is_windows(): + # These <k,v> environment variables will subsequently + # be used in a subprocess.Popen() call, as the `env` parameter. + # In some windows python versions, if an environment variable + # name is not of type str (e.g. unicode), the Popen call will + # fail. + k = str(k) + # The windows shell removes all double quotes - escape them + # to still be able to pass JSON in env vars to the shell. + v = v.replace('"', '\\"') + del env_vars[k] + env_vars[k] = str(v) + process['env'] = env_vars + args = process.get('args') + command = script_path + command_prefix = process.get('command_prefix') + if command_prefix: + command = '{0} {1}'.format(command_prefix, command) + if args: + command = ' '.join([command] + [str(a) for a in args]) + process['command'] = command + return process + + +def patch_ctx(ctx): + ctx._error = None + task = ctx.task + + def _validate_legal_action(): + if ctx._error is not None: + ctx._error = RuntimeError(constants.ILLEGAL_CTX_OPERATION_MESSAGE) + raise ctx._error + + def abort_operation(message=None): + _validate_legal_action() + ctx._error = exceptions.ScriptException(message=message, retry=False) + return ctx._error + task.abort = abort_operation + + def retry_operation(message=None, retry_interval=None): + _validate_legal_action() + ctx._error = exceptions.ScriptException(message=message, + retry=True, + retry_interval=retry_interval) + return ctx._error + task.retry = retry_operation + + +def check_error(ctx, error_check_func=None, reraise=False): + _error = ctx._error + # this happens when a script calls task.abort/task.retry more than once + if isinstance(_error, RuntimeError): + ctx.task.abort(str(_error)) + # ScriptException is populated by the ctx proxy server when task.abort or task.retry + # are called + elif isinstance(_error, exceptions.ScriptException): + if _error.retry: + ctx.task.retry(_error.message, _error.retry_interval) + else: + ctx.task.abort(_error.message) + # local and ssh operations may pass an additional logic check for errors here + if error_check_func: + error_check_func() + # if this function is called from within an ``except`` clause, a re-raise maybe required + if reraise: + raise # pylint: disable=misplaced-bare-raise + return _error diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/constants.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/constants.py new file mode 100644 index 0000000..1953912 --- /dev/null +++ b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/constants.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Execution plugin constants. +""" +import os +import tempfile + +from . import exceptions + +# related to local +PYTHON_SCRIPT_FILE_EXTENSION = '.py' +POWERSHELL_SCRIPT_FILE_EXTENSION = '.ps1' +DEFAULT_POWERSHELL_EXECUTABLE = 'powershell' + +# related to both local and ssh +ILLEGAL_CTX_OPERATION_MESSAGE = 'ctx may only abort or retry once' + +# related to ssh +DEFAULT_BASE_DIR = os.path.join(tempfile.gettempdir(), 'aria-ctx') +FABRIC_ENV_DEFAULTS = { + 'connection_attempts': 5, + 'timeout': 10, + 'forward_agent': False, + 'abort_on_prompts': True, + 'keepalive': 0, + 'linewise': False, + 'pool_size': 0, + 'skip_bad_hosts': False, + 'status': False, + 'disable_known_hosts': True, + 'combine_stderr': True, + 'abort_exception': exceptions.TaskException, +} +VALID_FABRIC_GROUPS = set([ + 'status', + 'aborts', + 'warnings', + 'running', + 'stdout', + 'stderr', + 'user', + 'everything' +]) diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ctx_proxy/__init__.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ctx_proxy/__init__.py new file mode 100644 index 0000000..46c8cf1 --- /dev/null +++ b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ctx_proxy/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +``ctx`` proxy. +""" + +from . import server, client diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ctx_proxy/client.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ctx_proxy/client.py new file mode 100644 index 0000000..84d66f1 --- /dev/null +++ b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ctx_proxy/client.py @@ -0,0 +1,114 @@ +#! /usr/bin/env python +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +``ctx`` proxy client implementation. +""" + +import argparse +import json +import os +import sys +import urllib2 + + +# Environment variable for the socket url (used by clients to locate the socket) +CTX_SOCKET_URL = 'CTX_SOCKET_URL' + + +class _RequestError(RuntimeError): + + def __init__(self, ex_message, ex_type, ex_traceback): + super(_RequestError, self).__init__(self, '{0}: {1}'.format(ex_type, ex_message)) + self.ex_type = ex_type + self.ex_message = ex_message + self.ex_traceback = ex_traceback + + +def _http_request(socket_url, request, method, timeout): + opener = urllib2.build_opener(urllib2.HTTPHandler) + request = urllib2.Request(socket_url, data=json.dumps(request)) + request.get_method = lambda: method + response = opener.open(request, timeout=timeout) + + if response.code != 200: + raise RuntimeError('Request failed: {0}'.format(response)) + return json.loads(response.read()) + + +def _client_request(socket_url, args, timeout, method='POST'): + response = _http_request( + socket_url=socket_url, + request={'args': args}, + method=method, + timeout=timeout + ) + payload = response.get('payload') + response_type = response.get('type') + if response_type == 'error': + ex_type = payload['type'] + ex_message = payload['message'] + ex_traceback = payload['traceback'] + raise _RequestError(ex_message, ex_type, ex_traceback) + elif response_type == 'stop_operation': + raise SystemExit(payload['message']) + else: + return payload + + +def _parse_args(args): + parser = argparse.ArgumentParser() + parser.add_argument('-t', '--timeout', type=int, default=30) + parser.add_argument('--socket-url', default=os.environ.get(CTX_SOCKET_URL)) + parser.add_argument('--json-arg-prefix', default='@') + parser.add_argument('-j', '--json-output', action='store_true') + parser.add_argument('args', nargs='*') + args = parser.parse_args(args=args) + if not args.socket_url: + raise RuntimeError('Missing CTX_SOCKET_URL environment variable ' + 'or socket_url command line argument. (ctx is supposed to be executed ' + 'within an operation context)') + return args + + +def _process_args(json_prefix, args): + processed_args = [] + for arg in args: + if arg.startswith(json_prefix): + arg = json.loads(arg[1:]) + processed_args.append(arg) + return processed_args + + +def main(args=None): + args = _parse_args(args) + response = _client_request( + args.socket_url, + args=_process_args(args.json_arg_prefix, args.args), + timeout=args.timeout) + if args.json_output: + response = json.dumps(response) + else: + if response is None: + response = '' + try: + response = str(response) + except UnicodeEncodeError: + response = unicode(response).encode('utf8') + sys.stdout.write(response) + +if __name__ == '__main__': + main() diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ctx_proxy/server.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ctx_proxy/server.py new file mode 100644 index 0000000..91b95d9 --- /dev/null +++ b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ctx_proxy/server.py @@ -0,0 +1,244 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +``ctx`` proxy server implementation. +""" + +import json +import socket +import Queue +import StringIO +import threading +import traceback +import wsgiref.simple_server + +import bottle +from aria import modeling + +from .. import exceptions + + +class CtxProxy(object): + + def __init__(self, ctx, ctx_patcher=(lambda *args, **kwargs: None)): + self.ctx = ctx + self._ctx_patcher = ctx_patcher + self.port = _get_unused_port() + self.socket_url = 'http://localhost:{0}'.format(self.port) + self.server = None + self._started = Queue.Queue(1) + self.thread = self._start_server() + self._started.get(timeout=5) + + def _start_server(self): + + class BottleServerAdapter(bottle.ServerAdapter): + proxy = self + + def close_session(self): + self.proxy.ctx.model.log._session.remove() + + def run(self, app): + + class Server(wsgiref.simple_server.WSGIServer): + allow_reuse_address = True + bottle_server = self + + def handle_error(self, request, client_address): + pass + + def serve_forever(self, poll_interval=0.5): + try: + wsgiref.simple_server.WSGIServer.serve_forever(self, poll_interval) + finally: + # Once shutdown is called, we need to close the session. + # If the session is not closed properly, it might raise warnings, + # or even lock the database. + self.bottle_server.close_session() + + class Handler(wsgiref.simple_server.WSGIRequestHandler): + def address_string(self): + return self.client_address[0] + + def log_request(*args, **kwargs): # pylint: disable=no-method-argument + if not self.quiet: + return wsgiref.simple_server.WSGIRequestHandler.log_request(*args, + **kwargs) + server = wsgiref.simple_server.make_server( + host=self.host, + port=self.port, + app=app, + server_class=Server, + handler_class=Handler) + self.proxy.server = server + self.proxy._started.put(True) + server.serve_forever(poll_interval=0.1) + + def serve(): + # Since task is a thread_local object, we need to patch it inside the server thread. + self._ctx_patcher(self.ctx) + + bottle_app = bottle.Bottle() + bottle_app.post('/', callback=self._request_handler) + bottle.run( + app=bottle_app, + host='localhost', + port=self.port, + quiet=True, + server=BottleServerAdapter) + thread = threading.Thread(target=serve) + thread.daemon = True + thread.start() + return thread + + def close(self): + if self.server: + self.server.shutdown() + self.server.server_close() + + def _request_handler(self): + request = bottle.request.body.read() # pylint: disable=no-member + response = self._process(request) + return bottle.LocalResponse( + body=json.dumps(response, cls=modeling.utils.ModelJSONEncoder), + status=200, + headers={'content-type': 'application/json'} + ) + + def _process(self, request): + try: + with self.ctx.model.instrument(*self.ctx.INSTRUMENTATION_FIELDS): + payload = _process_request(self.ctx, request) + result_type = 'result' + if isinstance(payload, exceptions.ScriptException): + payload = dict(message=str(payload)) + result_type = 'stop_operation' + result = {'type': result_type, 'payload': payload} + except Exception as e: + traceback_out = StringIO.StringIO() + traceback.print_exc(file=traceback_out) + payload = { + 'type': type(e).__name__, + 'message': str(e), + 'traceback': traceback_out.getvalue() + } + result = {'type': 'error', 'payload': payload} + + return result + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + self.close() + + +class CtxError(RuntimeError): + pass + + +class CtxParsingError(CtxError): + pass + + +def _process_request(ctx, request): + request = json.loads(request) + args = request['args'] + return _process_arguments(ctx, args) + + +def _process_arguments(obj, args): + # Modifying? + try: + # TODO: should there be a way to escape "=" in case it is needed as real argument? + equals_index = args.index('=') # raises ValueError if not found + except ValueError: + equals_index = None + if equals_index is not None: + if equals_index == 0: + raise CtxParsingError('The "=" argument cannot be first') + elif equals_index != len(args) - 2: + raise CtxParsingError('The "=" argument must be penultimate') + modifying = True + modifying_key = args[-3] + modifying_value = args[-1] + args = args[:-3] + else: + modifying = False + modifying_key = None + modifying_value = None + + # Parse all arguments + while len(args) > 0: + obj, args = _process_next_operation(obj, args, modifying) + + if modifying: + if hasattr(obj, '__setitem__'): + # Modify item value (dict, list, and similar) + if isinstance(obj, (list, tuple)): + modifying_key = int(modifying_key) + obj[modifying_key] = modifying_value + elif hasattr(obj, modifying_key): + # Modify object attribute + setattr(obj, modifying_key, modifying_value) + else: + raise CtxError('Cannot modify `{0}` of `{1!r}`'.format(modifying_key, obj)) + + return obj + + +def _process_next_operation(obj, args, modifying): + args = list(args) + arg = args.pop(0) + + # Call? + if arg == '[': + # TODO: should there be a way to escape "[" and "]" in case they are needed as real + # arguments? + try: + closing_index = args.index(']') # raises ValueError if not found + except ValueError: + raise CtxParsingError('Opening "[" without a closing "]') + callable_args = args[:closing_index] + args = args[closing_index + 1:] + if not callable(obj): + raise CtxError('Used "[" and "] on an object that is not callable') + return obj(*callable_args), args + + # Attribute? + if isinstance(arg, basestring): + if hasattr(obj, arg): + return getattr(obj, arg), args + token_sugared = arg.replace('-', '_') + if hasattr(obj, token_sugared): + return getattr(obj, token_sugared), args + + # Item? (dict, lists, and similar) + if hasattr(obj, '__getitem__'): + if modifying and (arg not in obj) and hasattr(obj, '__setitem__'): + # Create nested dict + obj[arg] = {} + return obj[arg], args + + raise CtxParsingError('Cannot parse argument: `{0!r}`'.format(arg)) + + +def _get_unused_port(): + sock = socket.socket() + sock.bind(('127.0.0.1', 0)) + _, port = sock.getsockname() + sock.close() + return port diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/environment_globals.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/environment_globals.py new file mode 100644 index 0000000..6dec293 --- /dev/null +++ b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/environment_globals.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilities for managing globals for the environment. +""" + +def create_initial_globals(path): + """ + Emulates a ``globals()`` call in a freshly loaded module. + + The implementation of this function is likely to raise a couple of questions. If you read the + implementation and nothing bothered you, feel free to skip the rest of this docstring. + + First, why is this function in its own module and not, say, in the same module of the other + environment-related functions? Second, why is it implemented in such a way that copies the + globals, then deletes the item that represents this function, and then changes some other + entries? + + Well, these two questions can be answered with one (elaborate) explanation. If this function was + in the same module with the other environment-related functions, then we would have had to + delete more items in globals than just ``create_initial_globals``. That is because all of the + other function names would also be in globals, and since there is no built-in mechanism that + return the name of the user-defined objects, this approach is quite an overkill. + + *But why do we rely on the copy-existing-globals-and-delete-entries method, when it seems to + force us to put ``create_initial_globals`` in its own file?* + + Well, because there is no easier method of creating globals of a newly loaded module. + + *How about hard coding a ``globals`` dict? It seems that there are very few entries: + ``__doc__``, ``__file__``, ``__name__``, ``__package__`` (but don't forget ``__builtins__``).* + + That would be coupling our implementation to a specific ``globals`` implementation. What if + ``globals`` were to change? + """ + copied_globals = globals().copy() + copied_globals.update({ + '__doc__': 'Dynamically executed script', + '__file__': path, + '__name__': '__main__', + '__package__': None + }) + del copied_globals[create_initial_globals.__name__] + return copied_globals diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/exceptions.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/exceptions.py new file mode 100644 index 0000000..f201fae --- /dev/null +++ b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/exceptions.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Execution plugin exceptions. +""" + +class ProcessException(Exception): + """ + Raised when local scripts and remote SSH commands fail. + """ + + def __init__(self, stderr=None, stdout=None, command=None, exit_code=None): + super(ProcessException, self).__init__(stderr) + self.command = command + self.exit_code = exit_code + self.stdout = stdout + self.stderr = stderr + + +class TaskException(Exception): + """ + Raised when remote ssh scripts fail. + """ + + +class ScriptException(Exception): + """ + Used by the ``ctx`` proxy server when task.retry or task.abort are called by scripts. + """ + + def __init__(self, message=None, retry=None, retry_interval=None): + super(ScriptException, self).__init__(message) + self.retry = retry + self.retry_interval = retry_interval diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/instantiation.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/instantiation.py new file mode 100644 index 0000000..8b52015 --- /dev/null +++ b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/instantiation.py @@ -0,0 +1,217 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Instantiation of :class:`~aria.modeling.models.Operation` models. +""" + +# TODO: this module will eventually be moved to a new "aria.instantiation" package +from ...modeling.functions import Function +from ... import utils + + +def configure_operation(operation, reporter): + host = None + interface = operation.interface + if interface.node is not None: + host = interface.node.host + elif interface.relationship is not None: + if operation.relationship_edge is True: + host = interface.relationship.target_node.host + else: # either False or None (None meaning that edge was not specified) + host = interface.relationship.source_node.host + + _configure_common(operation, reporter) + if host is None: + _configure_local(operation) + else: + _configure_remote(operation, reporter) + + # Any remaining un-handled configuration parameters will become extra arguments, available as + # kwargs in either "run_script_locally" or "run_script_with_ssh" + for key, value in operation.configurations.iteritems(): + if key not in ('process', 'ssh'): + operation.arguments[key] = value.instantiate(None) + + +def _configure_common(operation, reporter): + """ + Local and remote operations. + """ + + from ...modeling.models import Argument + operation.arguments['script_path'] = Argument.wrap('script_path', operation.implementation, + 'Relative path to the executable file.') + operation.arguments['process'] = Argument.wrap('process', _get_process(operation, reporter), + 'Sub-process configuration.') + + +def _configure_local(operation): + """ + Local operation. + """ + + from . import operations + operation.function = '{0}.{1}'.format(operations.__name__, + operations.run_script_locally.__name__) + + +def _configure_remote(operation, reporter): + """ + Remote SSH operation via Fabric. + """ + + from ...modeling.models import Argument + from . import operations + + ssh = _get_ssh(operation, reporter) + + # Defaults + # TODO: find a way to configure these generally in the service template + default_user = '' + default_password = '' + if 'user' not in ssh: + ssh['user'] = default_user + if ('password' not in ssh) and ('key' not in ssh) and ('key_filename' not in ssh): + ssh['password'] = default_password + + operation.arguments['use_sudo'] = Argument.wrap('use_sudo', ssh.get('use_sudo', False), + 'Whether to execute with sudo.') + + operation.arguments['hide_output'] = Argument.wrap('hide_output', ssh.get('hide_output', []), + 'Hide output of these Fabric groups.') + + fabric_env = {} + if 'warn_only' in ssh: + fabric_env['warn_only'] = ssh['warn_only'] + fabric_env['user'] = ssh.get('user') + fabric_env['password'] = ssh.get('password') + fabric_env['key'] = ssh.get('key') + fabric_env['key_filename'] = ssh.get('key_filename') + if 'address' in ssh: + fabric_env['host_string'] = ssh['address'] + + # Make sure we have a user + if fabric_env.get('user') is None: + reporter.report('must configure "ssh.user" for "{0}"'.format(operation.implementation), + level=reporter.Issue.BETWEEN_TYPES) + + # Make sure we have an authentication value + if (fabric_env.get('password') is None) and \ + (fabric_env.get('key') is None) and \ + (fabric_env.get('key_filename') is None): + reporter.report( + 'must configure "ssh.password", "ssh.key", or "ssh.key_filename" for "{0}"' + .format(operation.implementation), + level=reporter.Issue.BETWEEN_TYPES) + + operation.arguments['fabric_env'] = Argument.wrap('fabric_env', fabric_env, + 'Fabric configuration.') + + operation.function = '{0}.{1}'.format(operations.__name__, + operations.run_script_with_ssh.__name__) + + +def _get_process(operation, reporter): + value = (operation.configurations.get('process')._value + if 'process' in operation.configurations + else None) + if value is None: + return {} + _validate_type(value, dict, 'process', reporter) + value = utils.collections.OrderedDict(value) + for k, v in value.iteritems(): + if k == 'eval_python': + value[k] = _coerce_bool(v, 'process.eval_python', reporter) + elif k == 'cwd': + _validate_type(v, basestring, 'process.cwd', reporter) + elif k == 'command_prefix': + _validate_type(v, basestring, 'process.command_prefix', reporter) + elif k == 'args': + value[k] = _dict_to_list_of_strings(v, 'process.args', reporter) + elif k == 'env': + _validate_type(v, dict, 'process.env', reporter) + else: + reporter.report('unsupported configuration parameter: "process.{0}"'.format(k), + level=reporter.Issue.BETWEEN_TYPES) + return value + + +def _get_ssh(operation, reporter): + value = (operation.configurations.get('ssh')._value + if 'ssh' in operation.configurations + else None) + if value is None: + return {} + _validate_type(value, dict, 'ssh', reporter) + value = utils.collections.OrderedDict(value) + for k, v in value.iteritems(): + if k == 'use_sudo': + value[k] = _coerce_bool(v, 'ssh.use_sudo', reporter) + elif k == 'hide_output': + value[k] = _dict_to_list_of_strings(v, 'ssh.hide_output', reporter) + elif k == 'warn_only': + value[k] = _coerce_bool(v, 'ssh.warn_only', reporter) + elif k == 'user': + _validate_type(v, basestring, 'ssh.user', reporter) + elif k == 'password': + _validate_type(v, basestring, 'ssh.password', reporter) + elif k == 'key': + _validate_type(v, basestring, 'ssh.key', reporter) + elif k == 'key_filename': + _validate_type(v, basestring, 'ssh.key_filename', reporter) + elif k == 'address': + _validate_type(v, basestring, 'ssh.address', reporter) + else: + reporter.report('unsupported configuration parameter: "ssh.{0}"'.format(k), + level=reporter.Issue.BETWEEN_TYPES) + return value + + +def _validate_type(value, the_type, name, reporter): + if isinstance(value, Function): + return + if not isinstance(value, the_type): + reporter.report( + '"{0}" configuration is not a {1}: {2}'.format( + name, utils.type.full_type_name(the_type), utils.formatting.safe_repr(value)), + level=reporter.Issue.BETWEEN_TYPES) + + +def _coerce_bool(value, name, reporter): + if value is None: + return None + if isinstance(value, bool): + return value + _validate_type(value, basestring, name, reporter) + if value == 'true': + return True + elif value == 'false': + return False + else: + reporter.report( + '"{0}" configuration is not "true" or "false": {1}'.format( + name, utils.formatting.safe_repr(value)), + level=reporter.Issue.BETWEEN_TYPES) + + +def _dict_to_list_of_strings(the_dict, name, reporter): + _validate_type(the_dict, dict, name, reporter) + value = [] + for k in sorted(the_dict): + v = the_dict[k] + _validate_type(v, basestring, '{0}.{1}'.format(name, k), reporter) + value.append(v) + return value diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/local.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/local.py new file mode 100644 index 0000000..04b9ecd --- /dev/null +++ b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/local.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Local execution of operations. +""" + +import os +import subprocess +import threading +import StringIO + +from . import ctx_proxy +from . import exceptions +from . import common +from . import constants +from . import environment_globals +from . import python_script_scope + + +def run_script(ctx, script_path, process, **kwargs): + if not script_path: + ctx.task.abort('Missing script_path') + process = process or {} + script_path = common.download_script(ctx, script_path) + script_func = _get_run_script_func(script_path, process) + return script_func( + ctx=ctx, + script_path=script_path, + process=process, + operation_kwargs=kwargs) + + +def _get_run_script_func(script_path, process): + if _treat_script_as_python_script(script_path, process): + return _eval_script_func + else: + if _treat_script_as_powershell_script(script_path): + process.setdefault('command_prefix', constants.DEFAULT_POWERSHELL_EXECUTABLE) + return _execute_func + + +def _treat_script_as_python_script(script_path, process): + eval_python = process.get('eval_python') + script_extension = os.path.splitext(script_path)[1].lower() + return (eval_python is True or (script_extension == constants.PYTHON_SCRIPT_FILE_EXTENSION and + eval_python is not False)) + + +def _treat_script_as_powershell_script(script_path): + script_extension = os.path.splitext(script_path)[1].lower() + return script_extension == constants.POWERSHELL_SCRIPT_FILE_EXTENSION + + +def _eval_script_func(script_path, ctx, operation_kwargs, **_): + with python_script_scope(operation_ctx=ctx, operation_inputs=operation_kwargs): + execfile(script_path, environment_globals.create_initial_globals(script_path)) + + +def _execute_func(script_path, ctx, process, operation_kwargs): + os.chmod(script_path, 0755) + process = common.create_process_config( + script_path=script_path, + process=process, + operation_kwargs=operation_kwargs) + command = process['command'] + env = os.environ.copy() + env.update(process['env']) + ctx.logger.info('Executing: {0}'.format(command)) + with ctx_proxy.server.CtxProxy(ctx, common.patch_ctx) as proxy: + env[ctx_proxy.client.CTX_SOCKET_URL] = proxy.socket_url + running_process = subprocess.Popen( + command, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + cwd=process.get('cwd'), + bufsize=1, + close_fds=not common.is_windows()) + stdout_consumer = _OutputConsumer(running_process.stdout) + stderr_consumer = _OutputConsumer(running_process.stderr) + exit_code = running_process.wait() + stdout_consumer.join() + stderr_consumer.join() + ctx.logger.info('Execution done (exit_code={0}): {1}'.format(exit_code, command)) + + def error_check_func(): + if exit_code: + raise exceptions.ProcessException( + command=command, + exit_code=exit_code, + stdout=stdout_consumer.read_output(), + stderr=stderr_consumer.read_output()) + return common.check_error(ctx, error_check_func=error_check_func) + + +class _OutputConsumer(object): + + def __init__(self, out): + self._out = out + self._buffer = StringIO.StringIO() + self._consumer = threading.Thread(target=self._consume_output) + self._consumer.daemon = True + self._consumer.start() + + def _consume_output(self): + for line in iter(self._out.readline, b''): + self._buffer.write(line) + self._out.close() + + def read_output(self): + return self._buffer.getvalue() + + def join(self): + self._consumer.join() diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/operations.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/operations.py new file mode 100644 index 0000000..0e987f4 --- /dev/null +++ b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/operations.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Entry point functions. +""" + +from aria.orchestrator import operation +from . import local as local_operations + + +@operation +def run_script_locally(ctx, + script_path, + process=None, + **kwargs): + return local_operations.run_script( + ctx=ctx, + script_path=script_path, + process=process, + **kwargs) + + +@operation +def run_script_with_ssh(ctx, + script_path, + fabric_env=None, + process=None, + use_sudo=False, + hide_output=None, + **kwargs): + return _try_import_ssh().run_script( + ctx=ctx, + script_path=script_path, + fabric_env=fabric_env, + process=process, + use_sudo=use_sudo, + hide_output=hide_output, + **kwargs) + + +@operation +def run_commands_with_ssh(ctx, + commands, + fabric_env=None, + use_sudo=False, + hide_output=None, + **_): + return _try_import_ssh().run_commands( + ctx=ctx, + commands=commands, + fabric_env=fabric_env, + use_sudo=use_sudo, + hide_output=hide_output) + + +def _try_import_ssh(): + try: + from .ssh import operations as ssh_operations + return ssh_operations + except Exception as e: + print(e) + raise RuntimeError('Failed to import SSH modules; Have you installed the ARIA SSH extra?') diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ssh/__init__.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ssh/__init__.py new file mode 100644 index 0000000..474deef --- /dev/null +++ b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ssh/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Remote execution of operations over SSH. +""" diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ssh/operations.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ssh/operations.py new file mode 100644 index 0000000..c40e783 --- /dev/null +++ b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ssh/operations.py @@ -0,0 +1,195 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilities for running commands remotely over SSH. +""" + +import os +import random +import string +import tempfile +import StringIO + +import fabric.api +import fabric.context_managers +import fabric.contrib.files + +from .. import constants +from .. import exceptions +from .. import common +from .. import ctx_proxy +from . import tunnel + + +_PROXY_CLIENT_PATH = ctx_proxy.client.__file__ +if _PROXY_CLIENT_PATH.endswith('.pyc'): + _PROXY_CLIENT_PATH = _PROXY_CLIENT_PATH[:-1] + + +def run_commands(ctx, commands, fabric_env, use_sudo, hide_output, **_): + """Runs the provider 'commands' in sequence + + :param commands: a list of commands to run + :param fabric_env: fabric configuration + """ + with fabric.api.settings(_hide_output(ctx, groups=hide_output), + **_fabric_env(ctx, fabric_env, warn_only=True)): + for command in commands: + ctx.logger.info('Running command: {0}'.format(command)) + run = fabric.api.sudo if use_sudo else fabric.api.run + result = run(command) + if result.failed: + raise exceptions.ProcessException( + command=result.command, + exit_code=result.return_code, + stdout=result.stdout, + stderr=result.stderr) + + +def run_script(ctx, script_path, fabric_env, process, use_sudo, hide_output, **kwargs): + process = process or {} + paths = _Paths(base_dir=process.get('base_dir', constants.DEFAULT_BASE_DIR), + local_script_path=common.download_script(ctx, script_path)) + with fabric.api.settings(_hide_output(ctx, groups=hide_output), + **_fabric_env(ctx, fabric_env, warn_only=False)): + # the remote host must have the ctx before running any fabric scripts + if not fabric.contrib.files.exists(paths.remote_ctx_path): + # there may be race conditions with other operations that + # may be running in parallel, so we pass -p to make sure + # we get 0 exit code if the directory already exists + fabric.api.run('mkdir -p {0} && mkdir -p {1}'.format(paths.remote_scripts_dir, + paths.remote_work_dir)) + # this file has to be present before using ctx + fabric.api.put(_PROXY_CLIENT_PATH, paths.remote_ctx_path) + process = common.create_process_config( + script_path=paths.remote_script_path, + process=process, + operation_kwargs=kwargs, + quote_json_env_vars=True) + fabric.api.put(paths.local_script_path, paths.remote_script_path) + with ctx_proxy.server.CtxProxy(ctx, _patch_ctx) as proxy: + local_port = proxy.port + with fabric.context_managers.cd(process.get('cwd', paths.remote_work_dir)): # pylint: disable=not-context-manager + with tunnel.remote(ctx, local_port=local_port) as remote_port: + local_socket_url = proxy.socket_url + remote_socket_url = local_socket_url.replace(str(local_port), str(remote_port)) + env_script = _write_environment_script_file( + process=process, + paths=paths, + local_socket_url=local_socket_url, + remote_socket_url=remote_socket_url) + fabric.api.put(env_script, paths.remote_env_script_path) + try: + command = 'source {0} && {1}'.format(paths.remote_env_script_path, + process['command']) + run = fabric.api.sudo if use_sudo else fabric.api.run + run(command) + except exceptions.TaskException: + return common.check_error(ctx, reraise=True) + return common.check_error(ctx) + + +def _patch_ctx(ctx): + common.patch_ctx(ctx) + original_download_resource = ctx.download_resource + original_download_resource_and_render = ctx.download_resource_and_render + + def _download_resource(func, destination, **kwargs): + handle, temp_local_path = tempfile.mkstemp() + os.close(handle) + try: + func(destination=temp_local_path, **kwargs) + return fabric.api.put(temp_local_path, destination) + finally: + os.remove(temp_local_path) + + def download_resource(destination, path=None): + _download_resource( + func=original_download_resource, + destination=destination, + path=path) + ctx.download_resource = download_resource + + def download_resource_and_render(destination, path=None, variables=None): + _download_resource( + func=original_download_resource_and_render, + destination=destination, + path=path, + variables=variables) + ctx.download_resource_and_render = download_resource_and_render + + +def _hide_output(ctx, groups): + """ Hides Fabric's output for every 'entity' in `groups` """ + groups = set(groups or []) + if not groups.issubset(constants.VALID_FABRIC_GROUPS): + ctx.task.abort('`hide_output` must be a subset of {0} (Provided: {1})' + .format(', '.join(constants.VALID_FABRIC_GROUPS), ', '.join(groups))) + return fabric.api.hide(*groups) + + +def _fabric_env(ctx, fabric_env, warn_only): + """Prepares fabric environment variables configuration""" + ctx.logger.debug('Preparing fabric environment...') + env = constants.FABRIC_ENV_DEFAULTS.copy() + env.update(fabric_env or {}) + env.setdefault('warn_only', warn_only) + # validations + if (not env.get('host_string')) and (ctx.task) and (ctx.task.actor) and (ctx.task.actor.host): + env['host_string'] = ctx.task.actor.host.host_address + if not env.get('host_string'): + ctx.task.abort('`host_string` not supplied and ip cannot be deduced automatically') + if not (env.get('password') or env.get('key_filename') or env.get('key')): + ctx.task.abort( + 'Access credentials not supplied ' + '(you must supply at least one of `key_filename`, `key` or `password`)') + if not env.get('user'): + ctx.task.abort('`user` not supplied') + ctx.logger.debug('Environment prepared successfully') + return env + + +def _write_environment_script_file(process, paths, local_socket_url, remote_socket_url): + env_script = StringIO.StringIO() + env = process['env'] + env['PATH'] = '{0}:$PATH'.format(paths.remote_ctx_dir) + env['PYTHONPATH'] = '{0}:$PYTHONPATH'.format(paths.remote_ctx_dir) + env_script.write('chmod +x {0}\n'.format(paths.remote_script_path)) + env_script.write('chmod +x {0}\n'.format(paths.remote_ctx_path)) + env.update({ + ctx_proxy.client.CTX_SOCKET_URL: remote_socket_url, + 'LOCAL_{0}'.format(ctx_proxy.client.CTX_SOCKET_URL): local_socket_url + }) + for key, value in env.iteritems(): + env_script.write('export {0}={1}\n'.format(key, value)) + return env_script + + +class _Paths(object): + + def __init__(self, base_dir, local_script_path): + self.local_script_path = local_script_path + self.remote_ctx_dir = base_dir + self.base_script_path = os.path.basename(self.local_script_path) + self.remote_ctx_path = '{0}/ctx'.format(self.remote_ctx_dir) + self.remote_scripts_dir = '{0}/scripts'.format(self.remote_ctx_dir) + self.remote_work_dir = '{0}/work'.format(self.remote_ctx_dir) + random_suffix = ''.join(random.choice(string.ascii_lowercase + string.digits) + for _ in range(8)) + remote_path_suffix = '{0}-{1}'.format(self.base_script_path, random_suffix) + self.remote_env_script_path = '{0}/env-{1}'.format(self.remote_scripts_dir, + remote_path_suffix) + self.remote_script_path = '{0}/{1}'.format(self.remote_scripts_dir, remote_path_suffix) diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ssh/tunnel.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ssh/tunnel.py new file mode 100644 index 0000000..e76d525 --- /dev/null +++ b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/execution_plugin/ssh/tunnel.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This implementation was copied from the Fabric project directly: +# https://github.com/fabric/fabric/blob/master/fabric/context_managers.py#L486 +# The purpose was to remove the rtunnel creation printouts here: +# https://github.com/fabric/fabric/blob/master/fabric/context_managers.py#L547 + + +import contextlib +import select +import socket + +import fabric.api +import fabric.state +import fabric.thread_handling + + +@contextlib.contextmanager +def remote(ctx, local_port, remote_port=0, local_host='localhost', remote_bind_address='127.0.0.1'): + """Create a tunnel forwarding a locally-visible port to the remote target.""" + sockets = [] + channels = [] + thread_handlers = [] + + def accept(channel, *args, **kwargs): + # This seemingly innocent statement seems to be doing nothing + # but the truth is far from it! + # calling fileno() on a paramiko channel the first time, creates + # the required plumbing to make the channel valid for select. + # While this would generally happen implicitly inside the _forwarder + # function when select is called, it may already be too late and may + # cause the select loop to hang. + # Specifically, when new data arrives to the channel, a flag is set + # on an "event" object which is what makes the select call work. + # problem is this will only happen if the event object is not None + # and it will be not-None only after channel.fileno() has been called + # for the first time. If we wait until _forwarder calls select for the + # first time it may be after initial data has reached the channel. + # calling it explicitly here in the paramiko transport main event loop + # guarantees this will not happen. + channel.fileno() + + channels.append(channel) + sock = socket.socket() + sockets.append(sock) + + try: + sock.connect((local_host, local_port)) + except Exception as e: + try: + channel.close() + except Exception as ex2: + close_error = ' (While trying to close channel: {0})'.format(ex2) + else: + close_error = '' + ctx.task.abort('[{0}] rtunnel: cannot connect to {1}:{2} ({3}){4}' + .format(fabric.api.env.host_string, local_host, local_port, e, + close_error)) + + thread_handler = fabric.thread_handling.ThreadHandler('fwd', _forwarder, channel, sock) + thread_handlers.append(thread_handler) + + transport = fabric.state.connections[fabric.api.env.host_string].get_transport() + remote_port = transport.request_port_forward( + remote_bind_address, remote_port, handler=accept) + + try: + yield remote_port + finally: + for sock, chan, thread_handler in zip(sockets, channels, thread_handlers): + sock.close() + chan.close() + thread_handler.thread.join() + thread_handler.raise_if_needed() + transport.cancel_port_forward(remote_bind_address, remote_port) + + +def _forwarder(chan, sock): + # Bidirectionally forward data between a socket and a Paramiko channel. + while True: + read = select.select([sock, chan], [], [])[0] + if sock in read: + data = sock.recv(1024) + if len(data) == 0: + break + chan.send(data) + if chan in read: + data = chan.recv(1024) + if len(data) == 0: + break + sock.send(data) + chan.close() + sock.close() |