summaryrefslogtreecommitdiffstats
path: root/azure/aria/aria-extension-cloudify/src/aria/aria/storage/sql_mapi.py
blob: 975ada7028ad9baf36586879b7c77e5a47cfc7c0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
SQLAlchemy implementation of the storage model API ("MAPI").
"""

import os
import platform

from sqlalchemy import (
    create_engine,
    orm,
)
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm.exc import StaleDataError

from aria.utils.collections import OrderedDict
from . import (
    api,
    exceptions,
    collection_instrumentation
)

# Maps the short predicate names accepted in ``filters`` dictionaries to the rich-comparison
# method that is looked up on the column object when the filter is applied.
_predicates = dict(
    (short_name, '__{0}__'.format(short_name))
    for short_name in ('ge', 'gt', 'lt', 'le', 'eq', 'ne')
)


class SQLAlchemyModelAPI(api.ModelAPI):
    """
    SQLAlchemy implementation of the storage model API ("MAPI").

    Every query is built against ``self.model_cls`` and, when ``self._instrumentation`` is truthy,
    results are wrapped by ``collection_instrumentation.instrument``. Both attributes are not set
    in this class; presumably they are provided by ``api.ModelAPI`` (verify against the base
    class).
    """

    def __init__(self,
                 engine,
                 session,
                 **kwargs):
        """
        :param engine: SQLAlchemy engine, used for creating and dropping tables
        :param session: SQLAlchemy session, used for queries and commits
        :param kwargs: forwarded untouched to the base class initializer
        """
        super(SQLAlchemyModelAPI, self).__init__(**kwargs)
        self._engine = engine
        self._session = session

    def get(self, entry_id, include=None, **kwargs):
        """
        Returns a single result based on the model class and element ID.

        :param entry_id: value matched against the model's ``id`` column
        :param include: optional list of column names to include in the query
        :return: the (possibly instrumented) first matching result
        :raises ~aria.storage.exceptions.NotFoundError: if the query yields no result
        """
        query = self._get_query(include, {'id': entry_id})
        result = query.first()

        if not result:
            raise exceptions.NotFoundError(
                'Requested `{0}` with ID `{1}` was not found'
                .format(self.model_cls.__name__, entry_id)
            )
        return self._instrument(result)

    def get_by_name(self, entry_name, include=None, **kwargs):
        """
        Returns the single result whose ``name`` column equals ``entry_name``.

        :param entry_name: value matched against the model's ``name`` column
        :param include: optional list of column names to include in the query
        :return: the (possibly instrumented) matching result
        :raises ~aria.storage.exceptions.NotFoundError: if nothing matches
        :raises ~aria.storage.exceptions.StorageError: if more than one result matches
        """
        assert hasattr(self.model_cls, 'name')
        result = self.list(include=include, filters={'name': entry_name})
        if not result:
            raise exceptions.NotFoundError(
                'Requested {0} with name `{1}` was not found'
                .format(self.model_cls.__name__, entry_name)
            )
        elif len(result) > 1:
            raise exceptions.StorageError(
                'Requested {0} with name `{1}` returned more than 1 value'
                .format(self.model_cls.__name__, entry_name)
            )
        else:
            return result[0]

    def list(self,
             include=None,
             filters=None,
             pagination=None,
             sort=None,
             **kwargs):
        """
        Returns a (possibly empty) :class:`ListResult` of results.

        :param include: optional list of column names to include in the query
        :param filters: optional dictionary where keys are column names to filter by, and values
         are a plain value, a list/tuple of values, or a ``{predicate: operand}`` dictionary
        :param pagination: optional dict with ``size`` and ``offset`` keys
        :param sort: optional dictionary where keys are column names to sort by, and values are
         the order (asc/desc)
        :return: :class:`ListResult` whose ``metadata`` holds ``total``, ``size`` and ``offset``
        """
        query = self._get_query(include, filters, sort)

        results, total, size, offset = self._paginate(query, pagination)

        return ListResult(
            dict(total=total, size=size, offset=offset),
            [self._instrument(result) for result in results]
        )

    def iter(self,
             include=None,
             filters=None,
             sort=None,
             **kwargs):
        """
        Lazily yields the (possibly instrumented) results of the query, one at a time.

        Accepts the same ``include``, ``filters`` and ``sort`` arguments as :meth:`list`; no
        pagination is applied.
        """
        for result in self._get_query(include, filters, sort):
            yield self._instrument(result)

    def put(self, entry, **kwargs):
        """
        Adds ``entry`` to the database session, and attempts to commit.

        :param entry: object to add to the session (presumably an instance of the model class)
        :return: the same ``entry`` object
        :raises ~aria.storage.exceptions.StorageError: if the commit fails
        """
        self._session.add(entry)
        self._safe_commit()
        return entry

    def delete(self, entry, **kwargs):
        """
        Deletes ``entry`` from the database, and attempts to commit.

        The relationships of ``entry`` are accessed before the deletion so that they are loaded
        on the returned object.

        :return: the deleted ``entry`` object
        :raises ~aria.storage.exceptions.StorageError: if the commit fails
        """
        self._load_relationships(entry)
        self._session.delete(entry)
        self._safe_commit()
        return entry

    def update(self, entry, **kwargs):
        """
        Adds ``entry`` to the database session, and attempts to commit (delegates to :meth:`put`).

        :return: updated instance
        """
        return self.put(entry)

    def refresh(self, entry):
        """
        Reloads the instance with fresh information from the database.

        :param entry: instance to be re-loaded from the database
        :return: refreshed instance
        """
        self._session.refresh(entry)
        self._load_relationships(entry)
        return entry

    def _destroy_connection(self):
        # Intentionally a no-op in this implementation.
        pass

    def _establish_connection(self):
        # Intentionally a no-op in this implementation.
        pass

    def create(self, checkfirst=True, create_all=True, **kwargs):
        """
        Creates the table of the model class.

        :param checkfirst: passed on to SQLAlchemy's table/metadata creation calls
        :param create_all: if true, also creates every table registered on the model's metadata
        """
        self.model_cls.__table__.create(self._engine, checkfirst=checkfirst)

        if create_all:
            # In order to create any models created dynamically (e.g. many-to-many helper tables are
            # created at runtime).
            self.model_cls.metadata.create_all(bind=self._engine, checkfirst=checkfirst)

    def drop(self):
        """
        Drops the table.
        """
        self.model_cls.__table__.drop(self._engine)

    def _safe_commit(self):
        """
        Tries to commit the changes in the session. If a SQLAlchemy error (or a ``ValueError``) is
        raised, the session is rolled back and the error is re-raised as a storage error.

        :raises ~aria.storage.exceptions.StorageError: on any caught commit failure; stale data
         errors are reported as a version conflict
        """
        try:
            self._session.commit()
        except StaleDataError as e:
            self._session.rollback()
            raise exceptions.StorageError('Version conflict: {0}'.format(str(e)))
        except (SQLAlchemyError, ValueError) as e:
            self._session.rollback()
            raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e)))

    def _get_base_query(self, include, joins):
        """
        Create the initial query from the model class and included columns.

        :param include: (possibly empty) list of columns to include in the query; sorted in place
        :param joins: attributes on which the query should be joined
        :return: SQLAlchemy query object
        """
        # If only some columns are included, query through the session object
        if include:
            # Make sure that attributes come before association proxies
            include.sort(key=lambda x: x.is_clause_element)
            query = self._session.query(*include)
        else:
            # If all columns should be returned, query directly from the model
            query = self._session.query(self.model_cls)

        query = query.join(*joins)
        return query

    @staticmethod
    def _get_joins(model_class, columns):
        """
        Gets all the attributes on which we need to join, in first-seen order and without
        duplicates.

        :param model_class: class on which the column names are looked up
        :param columns: set of all attribute names involved in the query
        """

        # Using an ordered dict (keyed by the attribute's string form) instead of a set because
        # order is important
        joins = OrderedDict()
        for column_name in columns:
            column = getattr(model_class, column_name)
            while not column.is_attribute:
                join_attr = column.local_attr
                # This is a hack, to deal with the fact that SQLA doesn't
                # fully support doing something like: `if join_attr in joins`,
                # because some SQLA elements have their own comparators
                join_attr_name = str(join_attr)
                if join_attr_name not in joins:
                    joins[join_attr_name] = join_attr
                column = column.remote_attr

        return joins.values()

    @staticmethod
    def _sort_query(query, sort=None):
        """
        Adds sorting clauses to the query.

        :param query: base SQL query
        :param sort: optional dictionary where keys are columns to sort by, and values are
         the order (asc/desc); any value other than ``'desc'`` leaves the column as is
        :return: SQLAlchemy query object
        """
        if sort:
            for column, order in sort.items():
                if order == 'desc':
                    column = column.desc()
                query = query.order_by(column)
        return query

    def _filter_query(self, query, filters):
        """
        Adds filter clauses to the query.

        :param query: base SQL query
        :param filters: dictionary where keys are columns to filter by, and values
         are values applicable for those columns (or lists of such values)
        :return: SQLAlchemy query object
        """
        return self._add_value_filter(query, filters)

    @staticmethod
    def _add_value_filter(query, filters):
        """
        Applies every entry of ``filters`` to the query.

        A dict value is treated as ``{method name: operand}`` pairs, each called on the column; a
        list or tuple becomes an ``IN`` clause; anything else becomes an equality test.
        """
        for column, value in filters.items():
            if isinstance(value, dict):
                for predicate, operand in value.items():
                    query = query.filter(getattr(column, predicate)(operand))
            elif isinstance(value, (list, tuple)):
                query = query.filter(column.in_(value))
            else:
                query = query.filter(column == value)

        return query

    def _get_query(self,
                   include=None,
                   filters=None,
                   sort=None):
        """
        Gets a SQL query object based on the params passed.

        :param include: optional list of columns to include in the query
        :param filters: optional dictionary where keys are column names to filter by, and values
         are values applicable for those columns (or lists of such values)
        :param sort: optional dictionary where keys are column names to sort by, and values are the
         order (asc/desc)
        :return: sorted and filtered query with only the relevant columns
        """
        include, filters, sort, joins = self._get_joins_and_converted_columns(
            include, filters, sort
        )
        filters = self._convert_operands(filters)

        query = self._get_base_query(include, joins)
        query = self._filter_query(query, filters)
        query = self._sort_query(query, sort)
        return query

    @staticmethod
    def _convert_operands(filters):
        """
        Translates the short predicate names used in filter conditions (``ge``, ``gt``, ...) to
        the comparison method names they stand for (``__ge__``, ``__gt__``, ...).

        A new mapping is returned and neither ``filters`` nor the condition dictionaries nested in
        it are modified, so callers can safely reuse the same filters for several queries.

        :param filters: dictionary of column to value, list of values, or
         ``{predicate: operand}`` dictionary
        :return: new dictionary with the same keys, in which every condition dictionary is
         replaced by a translated copy
        :raises ~aria.storage.exceptions.StorageError: if a predicate name is not supported
        """
        converted = {}
        for column, conditions in filters.items():
            if not isinstance(conditions, dict):
                # Plain values and lists of values need no translation
                converted[column] = conditions
                continue

            converted_conditions = {}
            for predicate, operand in conditions.items():
                if predicate not in _predicates:
                    raise exceptions.StorageError(
                        "{0} is not a valid predicate for filtering. Valid predicates are {1}"
                        .format(predicate, ', '.join(_predicates.keys())))
                converted_conditions[_predicates[predicate]] = operand
            converted[column] = converted_conditions

        return converted

    def _get_joins_and_converted_columns(self,
                                         include,
                                         filters,
                                         sort):
        """
        Gets a list of tables on which we need to join and the converted ``include``, ``filters``
        and ``sort`` arguments (converted to actual SQLAlchemy column/label objects instead of
        column names).
        """
        include = include or []
        filters = filters or dict()
        sort = sort or OrderedDict()

        all_columns = set(include) | set(filters.keys()) | set(sort.keys())
        joins = self._get_joins(self.model_cls, all_columns)

        include, filters, sort = self._get_columns_from_field_names(
            include, filters, sort
        )
        return include, filters, sort, joins

    def _get_columns_from_field_names(self,
                                      include,
                                      filters,
                                      sort):
        """
        Goes over the optional parameters (include, filters, sort), and replaces column names with
        actual SQLAlchemy column objects. New containers are returned; the arguments are not
        modified.
        """
        include = [self._get_column(c) for c in include]
        filters = dict((self._get_column(c), filters[c]) for c in filters)
        sort = OrderedDict((self._get_column(c), sort[c]) for c in sort)

        return include, filters, sort

    def _get_column(self, column_name):
        """
        Returns the column on which an action (filtering, sorting, etc.) would need to be performed.
        Can be either an attribute of the class, or an association proxy linked to a relationship
        in the class.
        """
        column = getattr(self.model_cls, column_name)
        if column.is_attribute:
            return column
        else:
            # We need to get to the underlying attribute, so we move on to the
            # next remote_attr until we reach one
            while not column.remote_attr.is_attribute:
                column = column.remote_attr
            # Put a label on the remote attribute with the name of the column
            return column.remote_attr.label(column_name)

    @staticmethod
    def _paginate(query, pagination):
        """
        Paginates the query by size and offset.

        :param query: current SQLAlchemy query object
        :param pagination: optional dict with size and offset keys
        :return: tuple with four elements:
         * results: the rows returned with ``size`` as the limit and ``offset`` as the offset, or
           all rows if no pagination was requested
         * the total count of items
         * ``size`` [default: 0]
         * ``offset`` [default: 0]
        """
        if pagination:
            size = pagination.get('size', 0)
            offset = pagination.get('offset', 0)
            total = query.order_by(None).count()  # Fastest way to count
            results = query.limit(size).offset(offset).all()
            return results, total, size, offset
        else:
            results = query.all()
            return results, len(results), 0, 0

    @staticmethod
    def _load_relationships(instance):
        """
        Helper method used to overcome a problem where the relationships that rely on joins aren't
        being loaded automatically.
        """
        for rel in instance.__mapper__.relationships:
            # Accessing the attribute is what triggers the load; the value itself is not needed
            getattr(instance, rel.key)

    def _instrument(self, model):
        """
        Wraps ``model`` with the configured instrumentation, or returns it untouched when no
        instrumentation is configured.
        """
        if self._instrumentation:
            return collection_instrumentation.instrument(self._instrumentation, model, self)
        else:
            return model


def init_storage(base_dir, filename='db.sqlite'):
    """
    Built-in ModelStorage initiator.

    Builds an SQLite URI from ``base_dir`` and ``filename``, then creates the SQLAlchemy engine
    and the scoped session that are to be passed to the MAPI.

    :param base_dir: directory of the database file
    :param filename: database file name
    :return: dict with the ``engine`` and ``session`` keys
    """
    db_path = os.path.join(base_dir, filename)

    # Windows paths have no single root but drive letters, so the path is appended as is (as a
    # relative path would be); everywhere else an extra leading slash is required.
    if 'Windows' in platform.system():
        root_prefix = ''
    else:
        root_prefix = '/'

    uri = 'sqlite:///{platform_char}{path}'.format(platform_char=root_prefix, path=db_path)

    engine = create_engine(uri, connect_args={'timeout': 15})
    session = orm.scoped_session(session_factory=orm.sessionmaker(bind=engine))

    return {'engine': engine, 'session': session}


class ListResult(list):
    """
    A list of requested items, carrying along metadata about the request.

    The list is also reachable through its own ``items`` attribute, and the metadata given at
    construction through ``metadata``.
    """
    def __init__(self, metadata, *args, **kwargs):
        # Everything but the metadata is handed over to the regular list initializer
        super(ListResult, self).__init__(*args, **kwargs)
        self.items = self
        self.metadata = metadata