summaryrefslogtreecommitdiffstats
path: root/runtime/model_api.py
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/model_api.py')
-rw-r--r--runtime/model_api.py215
1 files changed, 215 insertions, 0 deletions
diff --git a/runtime/model_api.py b/runtime/model_api.py
new file mode 100644
index 0000000..fd87333
--- /dev/null
+++ b/runtime/model_api.py
@@ -0,0 +1,215 @@
+# -------------------------------------------------------------------------
+# Copyright (c) 2020 AT&T Intellectual Property
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -------------------------------------------------------------------------
+#
+
+import json
+import traceback
+
+import mysql.connector
+from flask import g, Flask, Response
+
+from osdf.config.base import osdf_config
+from osdf.logging.osdf_logging import debug_log, error_log
+from osdf.operation.exceptions import BusinessException
+
+
+def init_db():
+ if is_db_enabled():
+ get_db()
+
+
+def get_db():
+ """Opens a new database connection if there is none yet for the
+ current application context.
+ """
+ if not hasattr(g, 'pg'):
+ properties = osdf_config['deployment']
+ host, db_port, db = properties["osdfDatabaseHost"], properties["osdfDatabasePort"], \
+ properties.get("osdfDatabaseSchema")
+ user, password = properties["osdfDatabaseUsername"], properties["osdfDatabasePassword"]
+ g.pg = mysql.connector.connect(host=host, port=db_port, user=user, password=password, database=db)
+ return g.pg
+
+
+def close_db():
+ """Closes the database again at the end of the request."""
+ if hasattr(g, 'pg'):
+ g.pg.close()
+
+
+app = Flask(__name__)
+
+
+def create_model_data(model_api):
+ with app.app_context():
+ try:
+ model_info = model_api['modelInfo']
+ model_id = model_info['modelId']
+ debug_log.debug(
+ "persisting model_api {}".format(model_id))
+ connection = get_db()
+ cursor = connection.cursor(buffered=True)
+ query = "SELECT model_id FROM optim_model_data WHERE model_id = %s"
+ values = (model_id,)
+ cursor.execute(query, values)
+ if cursor.fetchone() is None:
+ query = "INSERT INTO optim_model_data (model_id, model_content, description, solver_type) VALUES " \
+ "(%s, %s, %s, %s)"
+ values = (model_id, model_info['modelContent'], model_info.get('description'), model_info['solver'])
+ cursor.execute(query, values)
+ g.pg.commit()
+
+ debug_log.debug("A record successfully inserted for request_id: {}".format(model_id))
+ return retrieve_model_data(model_id)
+ close_db()
+ else:
+ query = "UPDATE optim_model_data SET model_content = %s, description = %s, solver_type = %s where " \
+ "model_id = %s "
+ values = (model_info['modelContent'], model_info.get('description'), model_info['solver'], model_id)
+ cursor.execute(query, values)
+ g.pg.commit()
+
+ return retrieve_model_data(model_id)
+ close_db()
+ except Exception as err:
+ error_log.error("error for request_id: {} - {}".format(model_id, traceback.format_exc()))
+ close_db()
+ raise BusinessException(err)
+
+
+def retrieve_model_data(model_id):
+ status, resp_data = get_model_data(model_id)
+
+ if status == 200:
+ resp = json.dumps(build_model_dict(resp_data))
+ return build_response(resp, status)
+ else:
+ resp = json.dumps({
+ 'modelId': model_id,
+ 'statusMessage': "Error retrieving the model data for model {} due to {}".format(model_id, resp_data)
+ })
+ return build_response(resp, status)
+
+
+def build_model_dict(resp_data, content_needed=True):
+ resp = {'modelId': resp_data[0], 'description': resp_data[2] if resp_data[2] else '',
+ 'solver': resp_data[3]}
+ if content_needed:
+ resp.update({'modelContent': resp_data[1]})
+ return resp
+
+
+def build_response(resp, status):
+ response = Response(resp, content_type='application/json; charset=utf-8')
+ response.headers.add('content-length', len(resp))
+ response.status_code = status
+ return response
+
+
+def delete_model_data(model_id):
+ with app.app_context():
+ try:
+ debug_log.debug("deleting model data given model_id = {}".format(model_id))
+ d = dict();
+ connection = get_db()
+ cursor = connection.cursor(buffered=True)
+ query = "delete from optim_model_data WHERE model_id = %s"
+ values = (model_id,)
+ cursor.execute(query, values)
+ g.pg.commit()
+ close_db()
+ resp = {
+ "statusMessage": "model data for modelId {} deleted".format(model_id)
+ }
+ return build_response(json.dumps(resp), 200)
+ except Exception as err:
+ error_log.error("error deleting model_id: {} - {}".format(model_id, traceback.format_exc()))
+ close_db()
+ raise BusinessException(err)
+
+
+def get_model_data(model_id):
+ with app.app_context():
+ try:
+ debug_log.debug("getting model data given model_id = {}".format(model_id))
+ d = dict();
+ connection = get_db()
+ cursor = connection.cursor(buffered=True)
+ query = "SELECT model_id, model_content, description, solver_type FROM optim_model_data WHERE model_id = %s"
+ values = (model_id,)
+ cursor.execute(query, values)
+ if cursor is None:
+ return 400, "FAILED"
+ else:
+ rows = cursor.fetchone()
+ if rows is not None:
+ index = 0
+ for row in rows:
+ d[index] = row
+ index = index + 1
+ return 200, d
+ else:
+ close_db()
+ return 500, "NOT_FOUND"
+ except Exception:
+ error_log.error("error for request_id: {} - {}".format(model_id, traceback.format_exc()))
+ close_db()
+ return 500, "FAILED"
+
+
+def retrieve_all_models():
+ status, resp_data = get_all_models()
+ model_list = []
+ if status == 200:
+ for r in resp_data:
+ model_list.append(build_model_dict(r, False))
+ resp = json.dumps(model_list)
+ return build_response(resp, status)
+
+ else:
+ resp = json.dumps({
+ 'statusMessage': "Error retrieving all the model data due to {}".format(resp_data)
+ })
+ return build_response(resp, status)
+
+
+def get_all_models():
+ with app.app_context():
+ try:
+ debug_log.debug("getting all model data".format())
+ connection = get_db()
+ cursor = connection.cursor(buffered=True)
+ query = "SELECT model_id, model_content, description, solver_type FROM optim_model_data"
+
+ cursor.execute(query)
+ if cursor is None:
+ return 400, "FAILED"
+ else:
+ rows = cursor.fetchall()
+ if rows is not None:
+ return 200, rows
+ else:
+ close_db()
+ return 500, "NOT_FOUND"
+ except Exception:
+ error_log.error("error for request_id: {}".format(traceback.format_exc()))
+ close_db()
+ return 500, "FAILED"
+
+
+def is_db_enabled():
+ return osdf_config['deployment'].get('isDatabaseEnabled', False)