1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
|
#!/usr/bin/env python
# coding: utf-8
# author = 'liuzhiyong'
# date = 20201204
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import subprocess
import sys

from flask import Flask, abort, request, jsonify

from global_setting import questions, tokenizer_ch, CUDA_VISIBLE_DEVICES
from create_squad_features import get_squad_feature_result
# Module-level Flask application; AI2Flask stores it as ``self.app`` and
# registers its HTTP routes on it.
app = Flask(import_name=__name__)
class AI2Flask:
    """HTTP front-end exposing question-answering prediction and model loading.

    Constructing an instance registers two routes on the module-level ``app``:

    * ``POST /api/online/predict`` -- answer every entry of ``questions``
      against the posted text.
    * ``POST /api/online/load`` -- run ``./load_model.sh`` to load a model.
    """

    def __init__(self, port=5000, workers=4, max_text_len=800,
                 predict_url='http://localhost:8502/v1/models/predict:predict'):
        """Register the HTTP routes on the shared ``app``.

        Args:
            port: Port handed to ``app.run`` by :meth:`start`.
            workers: Accepted for backward compatibility; currently unused.
            max_text_len: Maximum number of leading items of the posted
                ``text`` that are sent for prediction; ``None`` disables
                truncation.
            predict_url: Forwarded as ``url`` to ``get_squad_feature_result``.

        NOTE(review): routes are added to the module-level ``app`` on every
        construction. Flask refuses to map an existing endpoint to a different
        view function, so creating a second instance is expected to fail --
        confirm before relying on multiple instances.
        """
        self.app = app
        self.port = port
        self.max_text_len = max_text_len
        self.predict_url = predict_url

        @self.app.route('/api/online/predict', methods=['POST'])
        def text_analyse():
            """Answer every configured question against the posted text.

            Expects a JSON body with a required ``text`` field and an optional
            ``title`` field (defaults to 'Not available'). Only the first
            ``max_text_len`` items of ``text`` are used.

            Returns:
                A JSON object mapping each question to its answer, or an error
                object ``{"Des", "Result", "Status"}`` when processing fails.
                Aborts with 400 when the request body is missing or empty.
            """
            payload = request.json
            if not payload:
                abort(400)
            try:
                title = payload.get('title', 'Not available')
                # Slicing is a no-op for shorter input, so no length check is needed.
                text = payload['text'][:self.max_text_len]
                result = {}
                for ques in questions:
                    tmp = get_squad_feature_result(
                        title=title, text=text, tokenizer=tokenizer_ch,
                        question=[ques], url=self.predict_url)
                    result[ques] = dict(tmp)[ques]
                print('finished!!')
                # Same body as before, but labelled as JSON like the error responses.
                return self.app.response_class(json.dumps(result),
                                               mimetype='application/json')
            except KeyError as e:
                return jsonify({"Des": 'KeyError: {}'.format(str(e)), "Result": 'None', "Status": "Error"})
            except Exception as e:
                return jsonify({"Des": str(e), "Result": 'None', "Status": "Error"})

        @self.app.route('/api/online/load', methods=['POST'])
        def load_model():
            """Load a model by running ``./load_model.sh <path> <CUDA_VISIBLE_DEVICES>``.

            Expects a JSON body with a ``path`` field. Returns a status object
            ``{"Des", "Status"}``; aborts with 400 when the body is missing or
            empty.
            """
            if not request.json:
                abort(400)
            try:
                path = request.json['path']
                # Pass an argument list (no shell) so a client-supplied path
                # cannot inject shell commands and paths with spaces stay intact.
                flag = subprocess.call(['./load_model.sh', path, str(CUDA_VISIBLE_DEVICES)])
                if flag == 0:
                    return jsonify({"Des": "Model loaded successfully !", "Status": "OK"})
                return jsonify({"Des": "Model loaded failed , check the logs !", "Status": "Error"})
            except Exception as e:
                return jsonify({"Des": str(e), "Status": "Error"})

    def start(self):
        """Run the development server on all interfaces, handling requests in threads."""
        self.app.run(host="0.0.0.0", port=self.port, threaded=True)
if __name__ == '__main__':
    # Usage: python <script> [port]
    # Without a port argument, fall back to AI2Flask's default port instead of
    # crashing with IndexError.
    if len(sys.argv) > 1:
        server = AI2Flask(port=int(sys.argv[1]))
    else:
        server = AI2Flask()
    server.start()
|