path: root/scripts/api_squad_offline.py
Diffstat (limited to 'scripts/api_squad_offline.py')
-rw-r--r--  scripts/api_squad_offline.py  48
1 file changed, 27 insertions, 21 deletions
diff --git a/scripts/api_squad_offline.py b/scripts/api_squad_offline.py
index 1c98a10..8a05141 100644
--- a/scripts/api_squad_offline.py
+++ b/scripts/api_squad_offline.py
@@ -1,4 +1,3 @@
-
#!/usr/bin/env python
# coding: utf-8
@@ -9,25 +8,36 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import json
-import datetime
-import threading
-import time
from flask import Flask, abort, request, jsonify
from concurrent.futures import ThreadPoolExecutor
-import collections
-import math
import os
import random
import modeling
-import optimization
import tokenization
-import six
import tensorflow as tf
import sys
-from api_squad import *
-from global_setting import *
+
+from api_squad import FLAGS_max_seq_length
+from api_squad import FLAGS_do_lower_case
+from api_squad import FLAGS_use_tpu
+from api_squad import FLAGS_tpu_name
+from api_squad import FLAGS_tpu_zone
+from api_squad import FLAGS_gcp_project
+from api_squad import FLAGS_master
+from api_squad import FLAGS_save_checkpoints_steps
+from api_squad import FLAGS_iterations_per_loop
+from api_squad import FLAGS_num_tpu_cores
+from api_squad import FLAGS_warmup_proportion
+from api_squad import FLAGS_doc_stride
+from api_squad import model_fn_builder
+from api_squad import FeatureWriter
+from api_squad import convert_examples_to_features
+from api_squad import input_fn_builder
+
+from global_setting import CUDA_VISIBLE_DEVICES
+from global_setting import validate_flags_or_throw
+from global_setting import read_squad_examples
from global_setting import FLAGS_bert_config_file, FLAGS_vocab_file, FLAGS_init_checkpoint_squad, questions
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@@ -35,9 +45,10 @@ os.environ["CUDA_VISIBLE_DEVICES"] = str(CUDA_VISIBLE_DEVICES)
app = Flask(__name__)
+
def serving_input_fn():
input_ids = tf.placeholder(tf.int32, [None, FLAGS_max_seq_length], name='input_ids')
- unique_id = tf.placeholder(tf.int32,[None])
+ unique_id = tf.placeholder(tf.int32, [None])
input_mask = tf.placeholder(tf.int32, [None, FLAGS_max_seq_length], name='input_mask')
segment_ids = tf.placeholder(tf.int32, [None, FLAGS_max_seq_length], name='segment_ids')
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
@@ -45,13 +56,13 @@ def serving_input_fn():
'input_mask': input_mask,
'segment_ids': segment_ids,
'unique_ids': unique_id,
- })()
+ })()
return input_fn
+
def main(FLAGS_output_dir, FLAGS_init_checkpoint_squad, FLAGS_export_dir, FLAGS_predict_file=None, FLAGS_train_file=None, FLAGS_do_predict=False,
FLAGS_do_train=False, FLAGS_train_batch_size=16, FLAGS_predict_batch_size=8, FLAGS_learning_rate=5e-5, FLAGS_num_train_epochs=3.0,
FLAGS_max_answer_length=100, FLAGS_max_query_length=64, FLAGS_version_2_with_negative=False):
-
tf.logging.set_verbosity(tf.logging.INFO)
bert_config = modeling.BertConfig.from_json_file(FLAGS_bert_config_file)
@@ -60,7 +71,6 @@ def main(FLAGS_output_dir, FLAGS_init_checkpoint_squad, FLAGS_export_dir, FLAGS_
tf.gfile.MakeDirs(FLAGS_output_dir)
-
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS_vocab_file, do_lower_case=FLAGS_do_lower_case)
@@ -68,7 +78,6 @@ def main(FLAGS_output_dir, FLAGS_init_checkpoint_squad, FLAGS_export_dir, FLAGS_
if FLAGS_use_tpu and FLAGS_tpu_name:
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
FLAGS_tpu_name, zone=FLAGS_tpu_zone, project=FLAGS_gcp_project)
-
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
run_config = tf.contrib.tpu.RunConfig(
cluster=tpu_cluster_resolver,
@@ -86,8 +95,7 @@ def main(FLAGS_output_dir, FLAGS_init_checkpoint_squad, FLAGS_export_dir, FLAGS_
if FLAGS_do_train:
train_examples = read_squad_examples(
- input_file=FLAGS_train_file, is_training=True,questions = questions,FLAGS_version_2_with_negative = FLAGS_version_2_with_negative)
-
+ input_file=FLAGS_train_file, is_training=True, questions=questions, FLAGS_version_2_with_negative=FLAGS_version_2_with_negative)
num_train_steps = int(
len(train_examples) / FLAGS_train_batch_size * FLAGS_num_train_epochs)
num_warmup_steps = int(num_train_steps * FLAGS_warmup_proportion)
@@ -174,7 +182,7 @@ class AI2Flask:
@app.route('/api/offline/train', methods=['POST'])
def text_analyse():
- if not request.json or not 'task_id' in request.json:
+ if not request.json or 'task_id' not in request.json:
abort(400)
if check_threads():
return jsonify({"Des": "Task list is full. Can not submit new task! ", "Result": "Failed to submit the training task ", "Status": "ERROR"})
@@ -227,8 +235,6 @@ class AI2Flask:
except Exception as e:
return jsonify({"Des": str(e), "Result": 'None', "Status": "Error"})
-
-
@app.route('/api/offline/status', methods=['POST'])
def todo_status():
task_id = request.json['task_id']
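
The hunks above adjust the Flask routes for submitting an offline training task and polling its status. A minimal client sketch follows; it assumes the service listens on localhost:5000 and uses "demo-001" as a placeholder task identifier, since only the required "task_id" field and the route paths are confirmed by the code in this diff.

```python
# Minimal client sketch for the routes shown above. Host, port, and the
# "demo-001" identifier are assumptions for illustration only.
import requests

BASE_URL = "http://localhost:5000"  # assumed address of the Flask app

# Submit an offline training task; the route aborts with HTTP 400 if the
# body is not JSON or lacks a "task_id" field.
resp = requests.post(f"{BASE_URL}/api/offline/train", json={"task_id": "demo-001"})
print(resp.json())

# Poll the task status using the same task_id.
status = requests.post(f"{BASE_URL}/api/offline/status", json={"task_id": "demo-001"})
print(status.json())
```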