diff options
Diffstat (limited to 'azure/aria/aria-extension-cloudify/src/aria/aria/storage/sql_mapi.py')
-rw-r--r-- | azure/aria/aria-extension-cloudify/src/aria/aria/storage/sql_mapi.py | 439 |
1 files changed, 439 insertions, 0 deletions
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
SQLAlchemy implementation of the storage model API ("MAPI").
"""

import os
import platform

from sqlalchemy import (
    create_engine,
    orm,
)
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm.exc import StaleDataError

from aria.utils.collections import OrderedDict
from . import (
    api,
    exceptions,
    collection_instrumentation
)

# Maps the short predicate names accepted inside ``filters`` to the name of the comparison method
# that is looked up on the column when the filter clause is built.
_predicates = {'ge': '__ge__',
               'gt': '__gt__',
               'lt': '__lt__',
               'le': '__le__',
               'eq': '__eq__',
               'ne': '__ne__'}


class SQLAlchemyModelAPI(api.ModelAPI):
    """
    SQLAlchemy implementation of the storage model API ("MAPI").

    Queries and commits go through the session handed to the constructor; table creation and
    removal go through the engine. ``self.model_cls`` and ``self._instrumentation`` are read
    throughout the class but are not set here -- presumably ``api.ModelAPI`` provides them
    (TODO confirm against the base class).
    """

    def __init__(self,
                 engine,
                 session,
                 **kwargs):
        """
        :param engine: SQLAlchemy engine, used for creating and dropping tables
        :param session: SQLAlchemy session, used for queries and commits
        :param kwargs: forwarded untouched to the base class initializer
        """
        super(SQLAlchemyModelAPI, self).__init__(**kwargs)
        self._engine = engine
        self._session = session

    def get(self, entry_id, include=None, **kwargs):
        """
        Returns a single result based on the model class and element ID.

        :param entry_id: value matched against the model's ``id`` field
        :param include: optional list of field names to include in the query
        :return: the first matching result, passed through :meth:`_instrument`
        :raises exceptions.NotFoundError: if nothing matches ``entry_id``
        """
        query = self._get_query(include, {'id': entry_id})
        result = query.first()

        # ``first()`` yields ``None`` when there is no row. Compare against ``None`` explicitly so
        # that a result object that happens to be falsy is not reported as missing.
        if result is None:
            raise exceptions.NotFoundError(
                'Requested `{0}` with ID `{1}` was not found'
                .format(self.model_cls.__name__, entry_id)
            )
        return self._instrument(result)

    def get_by_name(self, entry_name, include=None, **kwargs):
        """
        Returns the single result whose ``name`` field equals ``entry_name``.

        :param entry_name: value matched against the model's ``name`` field
        :param include: optional list of field names to include in the query
        :return: the only matching result
        :raises exceptions.NotFoundError: if nothing matches
        :raises exceptions.StorageError: if more than one result matches
        """
        # Programming-error guard: this lookup only makes sense for models that have a name.
        assert hasattr(self.model_cls, 'name')
        result = self.list(include=include, filters={'name': entry_name})
        if not result:
            raise exceptions.NotFoundError(
                'Requested {0} with name `{1}` was not found'
                .format(self.model_cls.__name__, entry_name)
            )
        if len(result) > 1:
            raise exceptions.StorageError(
                'Requested {0} with name `{1}` returned more than 1 value'
                .format(self.model_cls.__name__, entry_name)
            )
        return result[0]

    def list(self,
             include=None,
             filters=None,
             pagination=None,
             sort=None,
             **kwargs):
        """
        Returns a (possibly empty) :class:`ListResult` of results.

        :param include: optional list of field names to include in the query
        :param filters: optional dict of field name to value, list of values, or dict of
         ``{predicate: operand}`` (see :meth:`_add_value_filter`)
        :param pagination: optional dict with ``size`` and ``offset`` keys
        :param sort: optional dict of field name to order (``'desc'`` for descending)
        :return: :class:`ListResult` whose ``metadata`` holds ``total``, ``size`` and ``offset``
        """
        query = self._get_query(include, filters, sort)

        results, total, size, offset = self._paginate(query, pagination)

        return ListResult(
            dict(total=total, size=size, offset=offset),
            [self._instrument(result) for result in results]
        )

    def iter(self,
             include=None,
             filters=None,
             sort=None,
             **kwargs):
        """
        Yields the results one at a time instead of collecting them into a list.

        Takes the same ``include``, ``filters`` and ``sort`` arguments as :meth:`list`; there is
        no pagination.
        """
        for result in self._get_query(include, filters, sort):
            yield self._instrument(result)

    def put(self, entry, **kwargs):
        """
        Adds ``entry`` to the session and commits.

        :param entry: object to add to the session; presumably an instance of the model class
         (TODO confirm -- the object is handed to the session as-is, no conversion is done here)
        :return: the same ``entry`` object
        :raises exceptions.StorageError: if the commit fails (see :meth:`_safe_commit`)
        """
        self._session.add(entry)
        self._safe_commit()
        return entry

    def delete(self, entry, **kwargs):
        """
        Deletes ``entry`` through the session and commits.

        :param entry: instance to delete
        :return: the same ``entry`` object
        :raises exceptions.StorageError: if the commit fails (see :meth:`_safe_commit`)
        """
        # Touch every relationship before deleting (see ``_load_relationships``).
        self._load_relationships(entry)
        self._session.delete(entry)
        self._safe_commit()
        return entry

    def update(self, entry, **kwargs):
        """
        Adds ``entry`` to the database session, and attempts to commit.

        This simply delegates to :meth:`put`.

        :return: updated instance
        """
        return self.put(entry)

    def refresh(self, entry):
        """
        Reloads the instance with fresh information from the database.

        :param entry: instance to be re-loaded from the database
        :return: refreshed instance
        """
        self._session.refresh(entry)
        self._load_relationships(entry)
        return entry

    def _destroy_connection(self):
        """
        No-op in this implementation.
        """
        pass

    def _establish_connection(self):
        """
        No-op in this implementation.
        """
        pass

    def create(self, checkfirst=True, create_all=True, **kwargs):
        """
        Creates the table of the model class.

        :param checkfirst: forwarded to SQLAlchemy's table creation calls
        :param create_all: also create every table registered on the model's metadata
        """
        self.model_cls.__table__.create(self._engine, checkfirst=checkfirst)

        if create_all:
            # In order to create any models created dynamically (e.g. many-to-many helper tables
            # are created at runtime).
            self.model_cls.metadata.create_all(bind=self._engine, checkfirst=checkfirst)

    def drop(self):
        """
        Drops the table.
        """
        self.model_cls.__table__.drop(self._engine)

    def _safe_commit(self):
        """
        Tries to commit the changes in the session.

        If the commit raises a SQLAlchemy error or a ``ValueError``, the session is rolled back
        and the error is re-raised as a storage error.

        :raises exceptions.StorageError: with a "Version conflict" message for stale data, or a
         "SQL Storage error" message for any other caught error
        """
        try:
            self._session.commit()
        except StaleDataError as e:
            self._session.rollback()
            raise exceptions.StorageError('Version conflict: {0}'.format(str(e)))
        except (SQLAlchemyError, ValueError) as e:
            self._session.rollback()
            raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e)))

    def _get_base_query(self, include, joins):
        """
        Creates the initial query from the model class and included columns.

        :param include: (possibly empty) list of columns to include in the query
        :param joins: attributes to join on, as produced by :meth:`_get_joins`
        :return: SQLAlchemy query object
        """
        # If only some columns are included, query through the session object
        if include:
            # Make sure that attributes come before association proxies. Sort a copy so that the
            # list handed in by the caller is left untouched.
            ordered_include = sorted(include, key=lambda x: x.is_clause_element)
            query = self._session.query(*ordered_include)
        else:
            # If all columns should be returned, query directly from the model
            query = self._session.query(self.model_cls)

        query = query.join(*joins)
        return query

    @staticmethod
    def _get_joins(model_class, columns):
        """
        Gets a list of all the attributes on which we need to join.

        :param model_class: class whose attributes are looked up by name
        :param columns: set of all attribute names involved in the query
        :return: list of the attributes to join on, each appearing once, in discovery order
        """

        # Using an ordered dict instead of a set because order is important
        joins = OrderedDict()
        for column_name in columns:
            column = getattr(model_class, column_name)
            # Follow the chain of non-attribute columns (association proxies, judging by
            # ``_get_column``) until a plain attribute is reached.
            while not column.is_attribute:
                join_attr = column.local_attr
                # This is a hack, to deal with the fact that SQLA doesn't
                # fully support doing something like: `if join_attr in joins`,
                # because some SQLA elements have their own comparators
                join_attr_name = str(join_attr)
                if join_attr_name not in joins:
                    joins[join_attr_name] = join_attr
                column = column.remote_attr

        return list(joins.values())

    @staticmethod
    def _sort_query(query, sort=None):
        """
        Adds sorting clauses to the query.

        :param query: base SQL query
        :param sort: optional dictionary where keys are columns to sort by, and values are the
         order; only the value ``'desc'`` changes the ordering, any other value sorts by the
         column as-is
        :return: SQLAlchemy query object
        """
        if sort:
            for column, order in sort.items():
                if order == 'desc':
                    column = column.desc()
                query = query.order_by(column)
        return query

    def _filter_query(self, query, filters):
        """
        Adds filter clauses to the query.

        :param query: base SQL query
        :param filters: dictionary where keys are columns to filter by, and values are values
         applicable for those columns (or lists of such values)
        :return: SQLAlchemy query object
        """
        return self._add_value_filter(query, filters)

    @staticmethod
    def _add_value_filter(query, filters):
        """
        Adds one filter clause to the query per entry of ``filters``.

        The clause depends on the type of the value:

        * dict: every ``{method name: operand}`` pair becomes ``column.<method name>(operand)``
        * list or tuple: ``column IN (values)``
        * anything else: ``column == value``

        :param query: SQLAlchemy query object
        :param filters: dictionary of column objects to values, as described above
        :return: SQLAlchemy query object
        """
        for column, value in filters.items():
            if isinstance(value, dict):
                for predicate, operand in value.items():
                    query = query.filter(getattr(column, predicate)(operand))
            elif isinstance(value, (list, tuple)):
                query = query.filter(column.in_(value))
            else:
                query = query.filter(column == value)

        return query

    def _get_query(self,
                   include=None,
                   filters=None,
                   sort=None):
        """
        Gets a SQL query object based on the params passed.

        :param include: optional list of field names to include in the query
        :param filters: optional dictionary where keys are field names to filter by, and values
         are values applicable for those fields (or lists of such values, or dicts of
         ``{predicate: operand}``)
        :param sort: optional dictionary where keys are field names to sort by, and values are
         the order (asc/desc)
        :return: sorted and filtered query with only the relevant columns
        """
        include, filters, sort, joins = self._get_joins_and_converted_columns(
            include, filters, sort
        )
        filters = self._convert_operands(filters)

        query = self._get_base_query(include, joins)
        query = self._filter_query(query, filters)
        query = self._sort_query(query, sort)
        return query

    @staticmethod
    def _convert_operands(filters):
        """
        Replaces the short predicate names in dict-valued filters (``'ge'``, ``'lt'``, ...) with
        the comparison method names they stand for (``'__ge__'``, ``'__lt__'``, ...).

        New dictionaries are built instead of editing the given ones in place. Editing a
        dictionary while iterating over it is unreliable, and the inner dictionaries belong to the
        caller, who may want to use the same filters again.

        :param filters: dictionary of column objects to filter values
        :return: new dictionary with the same keys; dict values are replaced by converted copies,
         all other values are passed through unchanged
        :raises exceptions.StorageError: if a predicate name is not one of the supported ones
        """
        converted = {}
        for column, conditions in filters.items():
            if isinstance(conditions, dict):
                operands = {}
                for predicate, operand in conditions.items():
                    if predicate not in _predicates:
                        raise exceptions.StorageError(
                            "{0} is not a valid predicate for filtering. Valid predicates are {1}"
                            .format(predicate, ', '.join(_predicates.keys())))
                    operands[_predicates[predicate]] = operand
                conditions = operands
            converted[column] = conditions

        return converted

    def _get_joins_and_converted_columns(self,
                                         include,
                                         filters,
                                         sort):
        """
        Gets a list of attributes on which we need to join and the converted ``include``,
        ``filters`` and ``sort`` arguments (converted to actual SQLAlchemy column/label objects
        instead of column names).

        ``None`` (or otherwise empty) arguments are treated as empty collections.

        :return: tuple of ``(include, filters, sort, joins)``
        """
        include = include or []
        filters = filters or dict()
        sort = sort or OrderedDict()

        all_columns = set(include) | set(filters.keys()) | set(sort.keys())
        joins = self._get_joins(self.model_cls, all_columns)

        include, filters, sort = self._get_columns_from_field_names(
            include, filters, sort
        )
        return include, filters, sort, joins

    def _get_columns_from_field_names(self,
                                      include,
                                      filters,
                                      sort):
        """
        Goes over the optional parameters (include, filters, sort), and replaces column names
        with actual SQLAlchemy column objects.

        :return: tuple of ``(include, filters, sort)``: a list, a dict and an ordered dict
        """
        include = [self._get_column(c) for c in include]
        filters = dict((self._get_column(c), filters[c]) for c in filters)
        sort = OrderedDict((self._get_column(c), sort[c]) for c in sort)

        return include, filters, sort

    def _get_column(self, column_name):
        """
        Returns the column on which an action (filtering, sorting, etc.) would need to be
        performed. Can be either an attribute of the class, or an association proxy linked to a
        relationship in the class.

        :param column_name: name of an attribute of the model class
        :return: the attribute itself, or, for an association proxy, the underlying remote
         attribute labelled with ``column_name``
        """
        column = getattr(self.model_cls, column_name)
        if column.is_attribute:
            return column

        # We need to get to the underlying attribute, so we move on to the
        # next remote_attr until we reach one
        while not column.remote_attr.is_attribute:
            column = column.remote_attr
        # Put a label on the remote attribute with the name of the column
        return column.remote_attr.label(column_name)

    @staticmethod
    def _paginate(query, pagination):
        """
        Paginates the query by size and offset.

        A ``size`` of 0 (the default) means "no limit", which matches the ``size`` of 0 that is
        reported when no pagination is requested at all.

        :param query: current SQLAlchemy query object
        :param pagination: optional dict with size and offset keys
        :return: tuple with four elements:
         * results: ``size`` items starting from ``offset``
         * the total count of items
         * ``size`` [default: 0]
         * ``offset`` [default: 0]
        """
        if pagination:
            size = pagination.get('size', 0)
            offset = pagination.get('offset', 0)
            total = query.order_by(None).count()  # Fastest way to count
            if size:
                # Only limit when a size was actually asked for; an unconditional ``limit(0)``
                # would return no rows whenever ``size`` is left out.
                query = query.limit(size)
            results = query.offset(offset).all()
            return results, total, size, offset

        results = query.all()
        return results, len(results), 0, 0

    @staticmethod
    def _load_relationships(instance):
        """
        Helper method used to overcome a problem where the relationships that rely on joins aren't
        being loaded automatically.

        Simply reads every relationship attribute of the instance once.
        """
        for rel in instance.__mapper__.relationships:
            getattr(instance, rel.key)

    def _instrument(self, model):
        """
        Wraps ``model`` with collection instrumentation when ``self._instrumentation`` is set;
        otherwise returns it unchanged.
        """
        if self._instrumentation:
            return collection_instrumentation.instrument(self._instrumentation, model, self)
        return model


def init_storage(base_dir, filename='db.sqlite'):
    """
    Built-in ModelStorage initiator.

    Creates a SQLAlchemy engine and a session to be passed to the MAPI.

    ``initiator_kwargs`` must be passed to the ModelStorage which must hold the ``base_dir`` for
    the location of the database file, and an optional filename. This would create an SQLite
    database.

    :param base_dir: directory of the database
    :param filename: database file name.
    :return: dict with the ``engine`` and the (scoped) ``session``
    """
    uri = 'sqlite:///{platform_char}{path}'.format(
        # Handles the Windows behavior where there is no root, but drives.
        # Thus behaving as a relative path.
        platform_char='' if 'Windows' in platform.system() else '/',

        path=os.path.join(base_dir, filename))

    engine = create_engine(uri, connect_args=dict(timeout=15))

    session_factory = orm.sessionmaker(bind=engine)
    session = orm.scoped_session(session_factory=session_factory)

    return dict(engine=engine, session=session)


class ListResult(list):
    """
    List of the requested items that also carries metadata about the request (``total``, ``size``
    and ``offset`` when built by ``SQLAlchemyModelAPI.list``).
    """
    def __init__(self, metadata, *args, **kwargs):
        """
        :param metadata: stored as-is on the ``metadata`` attribute
        :param args: forwarded to the ``list`` initializer (typically the iterable of items)
        :param kwargs: forwarded to the ``list`` initializer
        """
        super(ListResult, self).__init__(*args, **kwargs)
        self.metadata = metadata
        # Alias, so that the items can also be reached as ``result.items``
        self.items = self