summaryrefslogtreecommitdiffstats
path: root/azure/aria/aria-extension-cloudify/src/aria/aria/parser/presentation/utils.py
blob: f0fd39018d16cacdd2555b551c83522e6d7f7c28 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from types import FunctionType

from ...utils.formatting import safe_repr
from ...utils.type import full_type_name
from ..validation import Issue
from .null import NULL


def get_locator(*values):
    """
    Returns the first non-``None`` ``_locator`` attribute found among the given values, in order.

    Values that have no ``_locator`` attribute, or whose ``_locator`` is ``None``, are skipped.

    :returns: the first available locator, or ``None`` if no value provides one
    :rtype: :class:`aria.parser.reading.Locator`
    """

    candidates = (v._locator for v in values if hasattr(v, '_locator'))
    return next((loc for loc in candidates if loc is not None), None)


def parse_types_dict_names(types_dict_names):
    """
    If the first element in the array is a function, extracts it out.

    Only plain Python functions (``types.FunctionType``, which includes lambdas) are recognized.
    Builtins and other callables are treated as ordinary names.

    :param types_dict_names: sequence of field names, optionally preceded by a conversion function;
     may be empty
    :returns: tuple of (remaining names, conversion function or ``None``)
    """

    # Guard against an empty sequence: indexing [0] would otherwise raise IndexError
    if types_dict_names and isinstance(types_dict_names[0], FunctionType):
        return types_dict_names[1:], types_dict_names[0]
    return types_dict_names, None


def validate_primitive(value, cls, coerce=False):
    """
    Checks if the value is of the primitive type, optionally attempting to coerce it
    if it is not.

    A ``None`` ``cls``, a ``None`` value, or the :obj:`NULL` value are returned unchanged without
    any check.

    :param value: value to check
    :param cls: expected primitive type; ``str`` and ``unicode`` are treated as interchangeable, and
     a ``bool`` is not accepted as an ``int``
    :param coerce: if ``True``, attempt ``cls(value)`` instead of failing when the type does not
     match
    :returns: the original value if it matches, otherwise the coerced value
    :raises ValueError: if not a primitive type or if coercion failed.
    """

    if (cls is None) or (value is None) or (value is NULL):
        return value

    if (cls is unicode) or (cls is str):
        # These two types are interchangeable
        valid = isinstance(value, basestring)
    elif cls is int:
        # In Python, a bool is an int, but we do not want to accept it as one
        valid = isinstance(value, int) and not isinstance(value, bool)
    else:
        valid = isinstance(value, cls)

    if valid:
        return value

    if not coerce:
        raise ValueError('not a "%s": %s' % (full_type_name(cls), safe_repr(value)))

    try:
        return cls(value)
    except TypeError:
        # Constructors such as int([]) or float({}) reject incompatible arguments with TypeError.
        # Translate it so callers only need to catch ValueError, as documented above.
        raise ValueError('not a "%s": %s' % (full_type_name(cls), safe_repr(value)))


def validate_no_short_form(context, presentation):
    """
    Reports a validation issue when ``presentation`` uses short form (its raw value is not a dict)
    even though it does not declare a ``SHORT_FORM_FIELD``.
    """

    # Short form is acceptable if explicitly allowed, and a dict is never short form
    if hasattr(presentation, 'SHORT_FORM_FIELD') or isinstance(presentation._raw, dict):
        return

    context.validation.report('short form not allowed for field "%s"' % presentation._fullname,
                              locator=presentation._locator,
                              level=Issue.BETWEEN_FIELDS)


def validate_no_unknown_fields(context, presentation):
    """
    Reports a validation issue for each key of ``presentation._raw`` that is not in
    ``presentation.FIELDS``.

    Nothing is checked when the presentation sets a truthy ``ALLOW_UNKNOWN_FIELDS``, when
    ``context.validation.allow_unknown_fields`` is truthy, when the raw value is not a dict, or when
    the presentation has no ``FIELDS`` attribute.
    """

    if getattr(presentation, 'ALLOW_UNKNOWN_FIELDS', False):
        return
    if context.validation.allow_unknown_fields:
        return
    if not isinstance(presentation._raw, dict):
        return
    if not hasattr(presentation, 'FIELDS'):
        return

    for key in presentation._raw:
        if key in presentation.FIELDS:
            continue
        context.validation.report('field "%s" is not supported in "%s"'
                                  % (key, presentation._fullname),
                                  locator=presentation._get_child_locator(key),
                                  level=Issue.BETWEEN_FIELDS)


def validate_known_fields(context, presentation):
    """
    Calls ``field.validate(presentation, context)`` for every ``(name, field)`` pair yielded by
    ``presentation._iter_fields()``. Does nothing if the presentation has no ``_iter_fields``.
    """

    if not hasattr(presentation, '_iter_fields'):
        return

    fields = (field for _name, field in presentation._iter_fields())
    for field in fields:
        field.validate(presentation, context)


def get_parent_presentation(context, presentation, *types_dict_names):
    """
    Resolves the parent presentation named by the ``derived_from`` field of ``presentation``.

    The arguments from the third onwards give the nested field path under ``service_template`` in
    the root presenter where the types dict is looked up. The first of them may optionally be a
    function, called as ``convert(context, type_name, types_dict)`` to translate type names. This
    can be used to support shorthand type names, aliases, etc.

    :returns: the parent presentation taken from the types dict, or ``None`` if ``derived_from``
     is ``None``, if the presentation derives from itself, if any name in its ancestry is missing
     from the types dict, or if the ancestry is circular
    """

    parent_name = presentation.derived_from
    if parent_name is None:
        return None

    names, convert = parse_types_dict_names(types_dict_names)
    types_dict = context.presentation.get('service_template', *names) or {}

    if convert:
        parent_name = convert(context, parent_name, types_dict)

    # Reject derivation from self and references to unknown types up front
    if (parent_name == presentation._name) or (parent_name not in types_dict):
        return None

    # Walk the whole ancestry so a cycle anywhere above this presentation is also rejected
    visited = [presentation._name]
    current = presentation
    while current.derived_from is not None:
        ancestor_name = current.derived_from
        if convert:
            ancestor_name = convert(context, ancestor_name, types_dict)

        if (ancestor_name == current._name) or (ancestor_name not in types_dict):
            return None

        current = types_dict[ancestor_name]
        if current._name in visited:
            return None
        visited.append(current._name)

    return types_dict[parent_name]


def report_issue_for_unknown_type(context, presentation, type_name, field_name, value=None):
    """
    Reports that ``field_name`` of ``presentation`` refers to a ``type_name`` that is unknown.

    :param type_name: human-readable kind of type, inserted into the message
    :param value: the offending value; when ``None``, it is read from ``presentation.<field_name>``
    """

    offending = getattr(presentation, field_name) if value is None else value
    message = ('"%s" refers to an unknown %s in "%s": %s'
               % (field_name, type_name, presentation._fullname, safe_repr(offending)))
    context.validation.report(message,
                              locator=presentation._get_child_locator(field_name),
                              level=Issue.BETWEEN_TYPES)


def report_issue_for_parent_is_self(context, presentation, field_name):
    """
    Reports that ``presentation`` names itself as its own parent type via ``field_name``.
    """

    message = 'parent type of "%s" is self' % presentation._fullname
    locator = presentation._get_child_locator(field_name)
    context.validation.report(message, locator=locator, level=Issue.BETWEEN_TYPES)


def report_issue_for_unknown_parent_type(context, presentation, field_name):
    """
    Reports that the parent type named by ``presentation.<field_name>`` is unknown.
    """

    parent = getattr(presentation, field_name)
    message = 'unknown parent type "%s" in "%s"' % (parent, presentation._fullname)
    locator = presentation._get_child_locator(field_name)
    context.validation.report(message, locator=locator, level=Issue.BETWEEN_TYPES)


def report_issue_for_circular_type_hierarchy(context, presentation, field_name):
    """
    Reports that the value of ``presentation.<field_name>`` creates a circular type hierarchy.
    """

    parent = getattr(presentation, field_name)
    message = '"%s" of "%s" creates a circular type hierarchy' % (parent, presentation._fullname)
    locator = presentation._get_child_locator(field_name)
    context.validation.report(message, locator=locator, level=Issue.BETWEEN_TYPES)