summaryrefslogtreecommitdiffstats
path: root/azure/aria/aria-extension-cloudify/src/aria/aria/storage/sql_mapi.py
diff options
context:
space:
mode:
Diffstat (limited to 'azure/aria/aria-extension-cloudify/src/aria/aria/storage/sql_mapi.py')
-rw-r--r--azure/aria/aria-extension-cloudify/src/aria/aria/storage/sql_mapi.py439
1 files changed, 439 insertions, 0 deletions
diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/storage/sql_mapi.py b/azure/aria/aria-extension-cloudify/src/aria/aria/storage/sql_mapi.py
new file mode 100644
index 0000000..975ada7
--- /dev/null
+++ b/azure/aria/aria-extension-cloudify/src/aria/aria/storage/sql_mapi.py
@@ -0,0 +1,439 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+SQLAlchemy implementation of the storage model API ("MAPI").
+"""
+
+import os
+import platform
+
+from sqlalchemy import (
+ create_engine,
+ orm,
+)
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm.exc import StaleDataError
+
+from aria.utils.collections import OrderedDict
+from . import (
+ api,
+ exceptions,
+ collection_instrumentation
+)
+
+_predicates = {'ge': '__ge__',
+ 'gt': '__gt__',
+ 'lt': '__lt__',
+ 'le': '__le__',
+ 'eq': '__eq__',
+ 'ne': '__ne__'}
+
+
+class SQLAlchemyModelAPI(api.ModelAPI):
+ """
+ SQLAlchemy implementation of the storage model API ("MAPI").
+ """
+
+ def __init__(self,
+ engine,
+ session,
+ **kwargs):
+ super(SQLAlchemyModelAPI, self).__init__(**kwargs)
+ self._engine = engine
+ self._session = session
+
+ def get(self, entry_id, include=None, **kwargs):
+ """
+ Returns a single result based on the model class and element ID
+ """
+ query = self._get_query(include, {'id': entry_id})
+ result = query.first()
+
+ if not result:
+ raise exceptions.NotFoundError(
+ 'Requested `{0}` with ID `{1}` was not found'
+ .format(self.model_cls.__name__, entry_id)
+ )
+ return self._instrument(result)
+
+ def get_by_name(self, entry_name, include=None, **kwargs):
+ assert hasattr(self.model_cls, 'name')
+ result = self.list(include=include, filters={'name': entry_name})
+ if not result:
+ raise exceptions.NotFoundError(
+ 'Requested {0} with name `{1}` was not found'
+ .format(self.model_cls.__name__, entry_name)
+ )
+ elif len(result) > 1:
+ raise exceptions.StorageError(
+ 'Requested {0} with name `{1}` returned more than 1 value'
+ .format(self.model_cls.__name__, entry_name)
+ )
+ else:
+ return result[0]
+
+ def list(self,
+ include=None,
+ filters=None,
+ pagination=None,
+ sort=None,
+ **kwargs):
+ query = self._get_query(include, filters, sort)
+
+ results, total, size, offset = self._paginate(query, pagination)
+
+ return ListResult(
+ dict(total=total, size=size, offset=offset),
+ [self._instrument(result) for result in results]
+ )
+
+ def iter(self,
+ include=None,
+ filters=None,
+ sort=None,
+ **kwargs):
+ """
+ Returns a (possibly empty) list of ``model_class`` results.
+ """
+ for result in self._get_query(include, filters, sort):
+ yield self._instrument(result)
+
+ def put(self, entry, **kwargs):
+ """
+ Creatse a ``model_class`` instance from a serializable ``model`` object.
+
+ :param entry: dict with relevant kwargs, or an instance of a class that has a ``to_dict``
+ method, and whose attributes match the columns of ``model_class`` (might also be just an
+ instance of ``model_class``)
+ :return: an instance of ``model_class``
+ """
+ self._session.add(entry)
+ self._safe_commit()
+ return entry
+
+ def delete(self, entry, **kwargs):
+ """
+ Deletes a single result based on the model class and element ID.
+ """
+ self._load_relationships(entry)
+ self._session.delete(entry)
+ self._safe_commit()
+ return entry
+
+ def update(self, entry, **kwargs):
+ """
+ Adds ``instance`` to the database session, and attempts to commit.
+
+ :return: updated instance
+ """
+ return self.put(entry)
+
+ def refresh(self, entry):
+ """
+ Reloads the instance with fresh information from the database.
+
+ :param entry: instance to be re-loaded from the database
+ :return: refreshed instance
+ """
+ self._session.refresh(entry)
+ self._load_relationships(entry)
+ return entry
+
+ def _destroy_connection(self):
+ pass
+
+ def _establish_connection(self):
+ pass
+
+ def create(self, checkfirst=True, create_all=True, **kwargs):
+ self.model_cls.__table__.create(self._engine, checkfirst=checkfirst)
+
+ if create_all:
+ # In order to create any models created dynamically (e.g. many-to-many helper tables are
+ # created at runtime).
+ self.model_cls.metadata.create_all(bind=self._engine, checkfirst=checkfirst)
+
+ def drop(self):
+ """
+ Drops the table.
+ """
+ self.model_cls.__table__.drop(self._engine)
+
+ def _safe_commit(self):
+ """
+ Try to commit changes in the session. Roll back if exception raised SQLAlchemy errors and
+ rolls back if they're caught.
+ """
+ try:
+ self._session.commit()
+ except StaleDataError as e:
+ self._session.rollback()
+ raise exceptions.StorageError('Version conflict: {0}'.format(str(e)))
+ except (SQLAlchemyError, ValueError) as e:
+ self._session.rollback()
+ raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e)))
+
+ def _get_base_query(self, include, joins):
+ """
+ Create the initial query from the model class and included columns.
+
+ :param include: (possibly empty) list of columns to include in the query
+ :return: SQLAlchemy AppenderQuery object
+ """
+ # If only some columns are included, query through the session object
+ if include:
+ # Make sure that attributes come before association proxies
+ include.sort(key=lambda x: x.is_clause_element)
+ query = self._session.query(*include)
+ else:
+ # If all columns should be returned, query directly from the model
+ query = self._session.query(self.model_cls)
+
+ query = query.join(*joins)
+ return query
+
+ @staticmethod
+ def _get_joins(model_class, columns):
+ """
+ Gets a list of all the tables on which we need to join.
+
+ :param columns: set of all attributes involved in the query
+ """
+
+ # Using a list instead of a set because order is important
+ joins = OrderedDict()
+ for column_name in columns:
+ column = getattr(model_class, column_name)
+ while not column.is_attribute:
+ join_attr = column.local_attr
+ # This is a hack, to deal with the fact that SQLA doesn't
+ # fully support doing something like: `if join_attr in joins`,
+ # because some SQLA elements have their own comparators
+ join_attr_name = str(join_attr)
+ if join_attr_name not in joins:
+ joins[join_attr_name] = join_attr
+ column = column.remote_attr
+
+ return joins.values()
+
+ @staticmethod
+ def _sort_query(query, sort=None):
+ """
+ Adds sorting clauses to the query.
+
+ :param query: base SQL query
+ :param sort: optional dictionary where keys are column names to sort by, and values are
+ the order (asc/desc)
+ :return: SQLAlchemy AppenderQuery object
+ """
+ if sort:
+ for column, order in sort.items():
+ if order == 'desc':
+ column = column.desc()
+ query = query.order_by(column)
+ return query
+
+ def _filter_query(self, query, filters):
+ """
+ Adds filter clauses to the query.
+
+ :param query: base SQL query
+ :param filters: optional dictionary where keys are column names to filter by, and values
+ are values applicable for those columns (or lists of such values)
+ :return: SQLAlchemy AppenderQuery object
+ """
+ return self._add_value_filter(query, filters)
+
+ @staticmethod
+ def _add_value_filter(query, filters):
+ for column, value in filters.items():
+ if isinstance(value, dict):
+ for predicate, operand in value.items():
+ query = query.filter(getattr(column, predicate)(operand))
+ elif isinstance(value, (list, tuple)):
+ query = query.filter(column.in_(value))
+ else:
+ query = query.filter(column == value)
+
+ return query
+
+ def _get_query(self,
+ include=None,
+ filters=None,
+ sort=None):
+ """
+ Gets a SQL query object based on the params passed.
+
+ :param model_class: SQL database table class
+ :param include: optional list of columns to include in the query
+ :param filters: optional dictionary where keys are column names to filter by, and values
+ are values applicable for those columns (or lists of such values)
+ :param sort: optional dictionary where keys are column names to sort by, and values are the
+ order (asc/desc)
+ :return: sorted and filtered query with only the relevant columns
+ """
+ include, filters, sort, joins = self._get_joins_and_converted_columns(
+ include, filters, sort
+ )
+ filters = self._convert_operands(filters)
+
+ query = self._get_base_query(include, joins)
+ query = self._filter_query(query, filters)
+ query = self._sort_query(query, sort)
+ return query
+
+ @staticmethod
+ def _convert_operands(filters):
+ for column, conditions in filters.items():
+ if isinstance(conditions, dict):
+ for predicate, operand in conditions.items():
+ if predicate not in _predicates:
+ raise exceptions.StorageError(
+ "{0} is not a valid predicate for filtering. Valid predicates are {1}"
+ .format(predicate, ', '.join(_predicates.keys())))
+ del filters[column][predicate]
+ filters[column][_predicates[predicate]] = operand
+
+
+ return filters
+
+ def _get_joins_and_converted_columns(self,
+ include,
+ filters,
+ sort):
+ """
+ Gets a list of tables on which we need to join and the converted ``include``, ``filters``
+ and ```sort`` arguments (converted to actual SQLAlchemy column/label objects instead of
+ column names).
+ """
+ include = include or []
+ filters = filters or dict()
+ sort = sort or OrderedDict()
+
+ all_columns = set(include) | set(filters.keys()) | set(sort.keys())
+ joins = self._get_joins(self.model_cls, all_columns)
+
+ include, filters, sort = self._get_columns_from_field_names(
+ include, filters, sort
+ )
+ return include, filters, sort, joins
+
+ def _get_columns_from_field_names(self,
+ include,
+ filters,
+ sort):
+ """
+ Gooes over the optional parameters (include, filters, sort), and replace column names with
+ actual SQLAlechmy column objects.
+ """
+ include = [self._get_column(c) for c in include]
+ filters = dict((self._get_column(c), filters[c]) for c in filters)
+ sort = OrderedDict((self._get_column(c), sort[c]) for c in sort)
+
+ return include, filters, sort
+
+ def _get_column(self, column_name):
+ """
+ Returns the column on which an action (filtering, sorting, etc.) would need to be performed.
+ Can be either an attribute of the class, or an association proxy linked to a relationship
+ in the class.
+ """
+ column = getattr(self.model_cls, column_name)
+ if column.is_attribute:
+ return column
+ else:
+ # We need to get to the underlying attribute, so we move on to the
+ # next remote_attr until we reach one
+ while not column.remote_attr.is_attribute:
+ column = column.remote_attr
+ # Put a label on the remote attribute with the name of the column
+ return column.remote_attr.label(column_name)
+
+ @staticmethod
+ def _paginate(query, pagination):
+ """
+ Paginates the query by size and offset.
+
+ :param query: current SQLAlchemy query object
+ :param pagination: optional dict with size and offset keys
+ :return: tuple with four elements:
+ * results: ``size`` items starting from ``offset``
+ * the total count of items
+ * ``size`` [default: 0]
+ * ``offset`` [default: 0]
+ """
+ if pagination:
+ size = pagination.get('size', 0)
+ offset = pagination.get('offset', 0)
+ total = query.order_by(None).count() # Fastest way to count
+ results = query.limit(size).offset(offset).all()
+ return results, total, size, offset
+ else:
+ results = query.all()
+ return results, len(results), 0, 0
+
+ @staticmethod
+ def _load_relationships(instance):
+ """
+ Helper method used to overcome a problem where the relationships that rely on joins aren't
+ being loaded automatically.
+ """
+ for rel in instance.__mapper__.relationships:
+ getattr(instance, rel.key)
+
+ def _instrument(self, model):
+ if self._instrumentation:
+ return collection_instrumentation.instrument(self._instrumentation, model, self)
+ else:
+ return model
+
+
+def init_storage(base_dir, filename='db.sqlite'):
+ """
+ Built-in ModelStorage initiator.
+
+ Creates a SQLAlchemy engine and a session to be passed to the MAPI.
+
+ ``initiator_kwargs`` must be passed to the ModelStorage which must hold the ``base_dir`` for the
+ location of the database file, and an option filename. This would create an SQLite database.
+
+ :param base_dir: directory of the database
+ :param filename: database file name.
+ :return:
+ """
+ uri = 'sqlite:///{platform_char}{path}'.format(
+ # Handles the windows behavior where there is not root, but drivers.
+ # Thus behaving as relative path.
+ platform_char='' if 'Windows' in platform.system() else '/',
+
+ path=os.path.join(base_dir, filename))
+
+ engine = create_engine(uri, connect_args=dict(timeout=15))
+
+ session_factory = orm.sessionmaker(bind=engine)
+ session = orm.scoped_session(session_factory=session_factory)
+
+ return dict(engine=engine, session=session)
+
+
+class ListResult(list):
+ """
+ Contains results about the requested items.
+ """
+ def __init__(self, metadata, *args, **qwargs):
+ super(ListResult, self).__init__(*args, **qwargs)
+ self.metadata = metadata
+ self.items = self