aboutsummaryrefslogtreecommitdiffstats
path: root/scripts/api_squad_online.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/api_squad_online.py')
-rw-r--r--scripts/api_squad_online.py39
1 files changed, 6 insertions, 33 deletions
diff --git a/scripts/api_squad_online.py b/scripts/api_squad_online.py
index 9cc6b08..abe3d5f 100644
--- a/scripts/api_squad_online.py
+++ b/scripts/api_squad_online.py
@@ -1,4 +1,3 @@
-
#!/usr/bin/env python
# coding: utf-8
@@ -6,30 +5,15 @@
# date = 20201204
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
-import datetime
-import threading
import sys
from flask import Flask, abort, request, jsonify
-from concurrent.futures import ThreadPoolExecutor
-import collections
-import math
import os
-import random
-import modeling
-import optimization
-import tokenization
-import six
-import tensorflow as tf
-import pandas as pd
-import numpy as np
-import requests
-from global_setting import *
+from global_setting import questions, tokenizer_ch, CUDA_VISIBLE_DEVICES
from create_squad_features import get_squad_feature_result
@@ -38,17 +22,15 @@ app = Flask(__name__)
class AI2Flask:
- def __init__(self, port=5000,workers=4):
+ def __init__(self, port=5000, workers=4):
self.app = app
self.port = port
-
-
@app.route('/api/online/predict', methods=['POST'])
def text_analyse():
if not request.json:
abort(400)
-
+
else:
try:
try:
@@ -56,7 +38,6 @@ class AI2Flask:
except:
title = 'Not available'
text_origin = request.json['text']
-
if len(text_origin) > 800:
text = text_origin[:800]
@@ -65,23 +46,16 @@ class AI2Flask:
result = {}
for ques in questions:
- tmp = get_squad_feature_result(title=title,text=text,tokenizer=tokenizer_ch,question=[ques],url='http://localhost:8502/v1/models/predict:predict')
+ tmp = get_squad_feature_result(title=title, text=text, tokenizer=tokenizer_ch, question=[ques], url='http://localhost:8502/v1/models/predict:predict')
result[ques] = dict(tmp)[ques]
-
-
+
print('finished!!')
return json.dumps(result)
-
-
+
except KeyError as e:
return jsonify({"Des": 'KeyError: {}'.format(str(e)), "Result": 'None', "Status": "Error"})
except Exception as e:
return jsonify({"Des": str(e), "Result": 'None', "Status": "Error"})
-
-
-
-
-
@app.route('/api/online/load', methods=['POST'])
def load_model():
@@ -105,4 +79,3 @@ class AI2Flask:
if __name__ == '__main__':
port = sys.argv[1]
AI2Flask(port=port).start()
-