summaryrefslogtreecommitdiffstats
path: root/azure/aria/aria-extension-cloudify/src/aria/tests/orchestrator/workflows/executor/__init__.py
blob: 99d0b39f2b6942b16550b2cef5a0f1a6100d46f8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import uuid
from contextlib import contextmanager

import aria
from aria.modeling import models
from aria.orchestrator.context.common import BaseContext


class MockContext(object):
    """
    Minimal stand-in for an operation context, used by the executor tests.

    It holds a model storage, builds a :class:`MockTask` from ``task_kwargs``
    and exposes just enough surface (``serialization_dict``,
    ``instantiate_from_dict``, ``model``, ``persist_changes``, ``close``) to be
    passed around in place of a real context.
    """

    # Mirrors the real context so code that reads this class attribute keeps working.
    INSTRUMENTATION_FIELDS = BaseContext.INSTRUMENTATION_FIELDS

    def __init__(self, storage, task_kwargs=None):
        """
        :param storage: model storage object; must expose ``serialization_dict``
         for :attr:`serialization_dict` to work
        :param task_kwargs: keyword arguments forwarded to :class:`MockTask`
         (``None`` is treated as an empty dict)
        """
        self.logger = logging.getLogger('mock_logger')
        self._task_kwargs = task_kwargs or {}
        self._storage = storage
        # Unpack the normalised dict, not the raw argument: ``**None`` would fail with
        # "argument after ** must be a mapping" and hide the real problem (MockTask
        # still requires ``function``).
        self.task = MockTask(storage, **self._task_kwargs)
        self.states = []
        self.exception = None

    @property
    def serialization_dict(self):
        """
        Dict describing how to rebuild this context: the context class plus the
        keyword arguments accepted by :meth:`instantiate_from_dict`.
        """
        return {
            'context_cls': self.__class__,
            'context': {
                'storage_kwargs': self._storage.serialization_dict,
                'task_kwargs': self._task_kwargs
            }
        }

    def __getattr__(self, item):
        # Only invoked when normal attribute lookup fails: every attribute that is not
        # explicitly defined on the mock resolves to ``None`` instead of raising.
        return None

    def close(self):
        """No-op; the mock holds nothing that needs releasing."""
        pass

    @property
    def model(self):
        """The storage object this context was created with."""
        return self._storage

    @classmethod
    def instantiate_from_dict(cls, storage_kwargs=None, task_kwargs=None):
        """
        Build a context from the ``'context'`` part of :attr:`serialization_dict`.

        :param storage_kwargs: keyword arguments for ``aria.application_model_storage``
         (``None`` is treated as an empty dict)
        :param task_kwargs: keyword arguments forwarded to :class:`MockTask`
         (``None`` is treated as an empty dict)
        :return: a new instance of ``cls``
        """
        return cls(storage=aria.application_model_storage(**(storage_kwargs or {})),
                   task_kwargs=(task_kwargs or {}))

    @property
    @contextmanager
    def persist_changes(self):
        """
        Context manager that does nothing on entry or exit; it is a property, so it
        is used as ``with ctx.persist_changes:``.
        """
        yield


class MockActor(object):
    """Bare actor object exposing only a fixed ``name`` attribute."""

    def __init__(self):
        # The name is a constant placeholder; every instance gets the same value.
        self.name = 'actor_name'


class MockTask(object):
    """
    Lightweight task object used by the executor tests.

    Carries the attributes set in ``__init__`` (function, arguments, retry
    counters, a :class:`MockActor`, ...) and one upper-cased attribute per entry
    of ``models.Task.STATES``.
    """

    # Same value as on the real task model.
    INFINITE_RETRIES = models.Task.INFINITE_RETRIES

    def __init__(self, model, function, arguments=None, plugin_fk=None):
        """
        :param model: storage object; its ``plugin.get`` is used by :attr:`plugin`
        :param function: stored as both ``function`` and ``name``
        :param arguments: task arguments; any falsy value becomes an empty dict
        :param plugin_fk: key passed to ``model.plugin.get``; falsy means no plugin
        """
        # Identity and logging.
        self.id = str(uuid.uuid4())
        self.logger = logging.getLogger()

        # What to run and with what.
        self.function = function
        self.name = function
        self.arguments = arguments if arguments else {}
        self.plugin_fk = plugin_fk
        self.model = model

        # Execution bookkeeping.
        self.states = []
        self.exception = None
        self.attempts_count = 1
        self.max_attempts = 1
        self.ignore_failure = False

        # Fixed placeholder names.
        self.interface_name = 'interface_name'
        self.operation_name = 'operation_name'

        # ``actor`` and ``node`` refer to the same MockActor instance.
        mock_actor = MockActor()
        self.actor = mock_actor
        self.node = mock_actor

        # Expose each task state as an attribute named after its upper-cased value,
        # e.g. a state 's' becomes ``self.S == 's'``.
        for task_state in models.Task.STATES:
            setattr(self, task_state.upper(), task_state)

    @property
    def plugin(self):
        """
        The plugin fetched from ``self.model.plugin`` by ``plugin_fk``, or ``None``
        when ``plugin_fk`` is falsy.
        """
        if not self.plugin_fk:
            return None
        return self.model.plugin.get(self.plugin_fk)