summaryrefslogtreecommitdiffstats
path: root/scripts/api_squad_online.py
blob: 9cc6b086772fc1b7e4e2695f3797be7938540302 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/env python
# coding: utf-8

# author = 'liuzhiyong'
# date = 20201204



from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import datetime
import threading
import sys
from flask import Flask, abort, request, jsonify
from concurrent.futures import ThreadPoolExecutor

import collections
import math
import os
import random
import modeling
import optimization
import tokenization
import six
import tensorflow as tf
import pandas as pd
import numpy as np
import requests
from global_setting import *
from create_squad_features import get_squad_feature_result


# Module-wide Flask application; AI2Flask registers its routes on this object.
app = Flask(import_name=__name__)


class AI2Flask:
    """HTTP front-end exposing QA prediction and model loading endpoints.

    Constructing an instance registers two POST routes on the module-level
    ``app``:

    * ``/api/online/predict`` -- runs every entry of ``questions`` against the
      posted ``text`` (and optional ``title``) through
      ``get_squad_feature_result`` and returns a JSON object keyed by question.
    * ``/api/online/load`` -- runs
      ``./load_model.sh <path> <CUDA_VISIBLE_DEVICES>`` and reports whether the
      script exited with status 0.

    ``questions``, ``tokenizer_ch`` and ``CUDA_VISIBLE_DEVICES`` are not
    defined here; presumably they come from the ``global_setting`` star
    import -- verify.

    NOTE(review): the routes are attached to the shared module-level ``app``,
    so constructing a second instance will likely make Flask reject the
    duplicate endpoint registrations.
    """

    def __init__(self, port=5000, workers=4,
                 predict_url='http://localhost:8502/v1/models/predict:predict',
                 max_text_len=800):
        """Register the HTTP routes on ``app`` and remember the listen port.

        :param port: port handed to ``app.run`` by :meth:`start`.
        :param workers: accepted for backward compatibility; currently unused.
        :param predict_url: model-server URL forwarded to
            ``get_squad_feature_result`` for every question.
        :param max_text_len: the posted ``text`` is truncated to at most this
            many items (characters, for a string) before prediction.
        """
        self.app = app
        self.port = port
        self.predict_url = predict_url
        self.max_text_len = max_text_len

        @app.route('/api/online/predict', methods=['POST'])
        def text_analyse():
            """Answer every configured question against the posted text.

            Expects a JSON body with a required ``text`` and an optional
            ``title`` (defaults to ``'Not available'``). Returns a JSON string
            mapping each question to its result; failures are reported as a
            ``{"Des", "Result", "Status"}`` JSON object.
            """
            payload = request.json
            if not payload:
                abort(400)

            try:
                title = payload.get('title', 'Not available')
                # Only the leading max_text_len characters are sent to the model.
                text = payload['text'][:self.max_text_len]

                result = {}
                for ques in questions:
                    tmp = get_squad_feature_result(
                        title=title,
                        text=text,
                        tokenizer=tokenizer_ch,
                        question=[ques],
                        url=self.predict_url,
                    )
                    result[ques] = dict(tmp)[ques]

                print('finished!!')
                return json.dumps(result)
            except KeyError as e:
                return jsonify({"Des": 'KeyError: {}'.format(str(e)), "Result": 'None', "Status": "Error"})
            except Exception as e:
                return jsonify({"Des": str(e), "Result": 'None', "Status": "Error"})

        @app.route('/api/online/load', methods=['POST'])
        def load_model():
            """Run ``./load_model.sh`` for the posted model ``path``.

            Expects a JSON body with a string ``path``. The script is invoked
            as ``./load_model.sh <path> <CUDA_VISIBLE_DEVICES>`` and an exit
            status of 0 is reported as success.
            """
            import subprocess

            payload = request.json
            if not payload:
                abort(400)

            try:
                path = payload['path']
                if not isinstance(path, str):
                    return jsonify({"Des": "'path' must be a string", "Status": "Error"})
                # Pass an argument list without a shell: 'path' comes straight
                # from the request and must not be able to inject commands.
                completed = subprocess.run(
                    ['./load_model.sh', path, str(CUDA_VISIBLE_DEVICES)])
                if completed.returncode == 0:
                    return jsonify({"Des": "Model loaded successfully !", "Status": "OK"})
                return jsonify({"Des": "Model loaded failed , check the logs !", "Status": "Error"})
            except Exception as e:
                return jsonify({"Des": str(e), "Status": "Error"})

    def start(self):
        """Serve ``self.app`` on all interfaces at ``self.port`` (threaded)."""
        self.app.run(host="0.0.0.0", port=self.port, threaded=True)


if __name__ == '__main__':
    # Usage: api_squad_online.py [PORT]; without an argument the service
    # listens on 5000 instead of crashing with IndexError.
    port = int(sys.argv[1]) if len(sys.argv) > 1 else 5000
    AI2Flask(port=port).start()