summaryrefslogtreecommitdiffstats
path: root/scripts/api_squad_online.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/api_squad_online.py')
-rw-r--r--scripts/api_squad_online.py108
1 files changed, 108 insertions, 0 deletions
diff --git a/scripts/api_squad_online.py b/scripts/api_squad_online.py
new file mode 100644
index 0000000..9cc6b08
--- /dev/null
+++ b/scripts/api_squad_online.py
@@ -0,0 +1,108 @@
+
+#!/usr/bin/env python
+# coding: utf-8
+
+# auther = 'liuzhiyong'
+# date = 20201204
+
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import json
+import datetime
+import threading
+import sys
+from flask import Flask, abort, request, jsonify
+from concurrent.futures import ThreadPoolExecutor
+
+import collections
+import math
+import os
+import random
+import modeling
+import optimization
+import tokenization
+import six
+import tensorflow as tf
+import pandas as pd
+import numpy as np
+import requests
+from global_setting import *
+from create_squad_features import get_squad_feature_result
+
+
+app = Flask(__name__)
+
+
+class AI2Flask:
+
+ def __init__(self, port=5000,workers=4):
+ self.app = app
+ self.port = port
+
+
+
+ @app.route('/api/online/predict', methods=['POST'])
+ def text_analyse():
+ if not request.json:
+ abort(400)
+
+ else:
+ try:
+ try:
+ title = request.json['title']
+ except:
+ title = 'Not available'
+ text_origin = request.json['text']
+
+
+ if len(text_origin) > 800:
+ text = text_origin[:800]
+ else:
+ text = text_origin
+
+ result = {}
+ for ques in questions:
+ tmp = get_squad_feature_result(title=title,text=text,tokenizer=tokenizer_ch,question=[ques],url='http://localhost:8502/v1/models/predict:predict')
+ result[ques] = dict(tmp)[ques]
+
+
+ print('finished!!')
+ return json.dumps(result)
+
+
+ except KeyError as e:
+ return jsonify({"Des": 'KeyError: {}'.format(str(e)), "Result": 'None', "Status": "Error"})
+ except Exception as e:
+ return jsonify({"Des": str(e), "Result": 'None', "Status": "Error"})
+
+
+
+
+
+
+ @app.route('/api/online/load', methods=['POST'])
+ def load_model():
+ if not request.json:
+ abort(400)
+ else:
+ try:
+ path = request.json['path']
+ flag = os.system('./load_model.sh ' + path + ' ' + CUDA_VISIBLE_DEVICES)
+ if flag == 0:
+ return jsonify({"Des": "Model loaded successfully !", "Status": "OK"})
+ else:
+ return jsonify({"Des": "Model loaded failed , check the logs !", "Status": "Error"})
+ except Exception as e:
+ return jsonify({"Des": str(e), "Status": "Error"})
+
+ def start(self):
+ self.app.run(host="0.0.0.0", port=self.port, threaded=True)
+
+
+if __name__ == '__main__':
+ port = sys.argv[1]
+ AI2Flask(port=port).start()
+