summaryrefslogtreecommitdiffstats
path: root/nlp/scripts/api_squad_offline.py
diff options
context:
space:
mode:
Diffstat (limited to 'nlp/scripts/api_squad_offline.py')
-rw-r--r--nlp/scripts/api_squad_offline.py264
1 files changed, 264 insertions, 0 deletions
diff --git a/nlp/scripts/api_squad_offline.py b/nlp/scripts/api_squad_offline.py
new file mode 100644
index 0000000..8a05141
--- /dev/null
+++ b/nlp/scripts/api_squad_offline.py
@@ -0,0 +1,264 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+# auther = 'liuzhiyong'
+# date = 20201204
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from flask import Flask, abort, request, jsonify
+from concurrent.futures import ThreadPoolExecutor
+
+import os
+import random
+import modeling
+import tokenization
+import tensorflow as tf
+import sys
+
+from api_squad import FLAGS_max_seq_length
+from api_squad import FLAGS_do_lower_case
+from api_squad import FLAGS_use_tpu
+from api_squad import FLAGS_tpu_name
+from api_squad import FLAGS_tpu_zone
+from api_squad import FLAGS_gcp_project
+from api_squad import FLAGS_master
+from api_squad import FLAGS_save_checkpoints_steps
+from api_squad import FLAGS_iterations_per_loop
+from api_squad import FLAGS_num_tpu_cores
+from api_squad import FLAGS_warmup_proportion
+from api_squad import FLAGS_doc_stride
+from api_squad import model_fn_builder
+from api_squad import FeatureWriter
+from api_squad import convert_examples_to_features
+from api_squad import input_fn_builder
+
+from global_setting import CUDA_VISIBLE_DEVICES
+from global_setting import validate_flags_or_throw
+from global_setting import read_squad_examples
+from global_setting import FLAGS_bert_config_file, FLAGS_vocab_file, FLAGS_init_checkpoint_squad, questions
+
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+os.environ["CUDA_VISIBLE_DEVICES"] = str(CUDA_VISIBLE_DEVICES)
+
+app = Flask(__name__)
+
+
+def serving_input_fn():
+ input_ids = tf.placeholder(tf.int32, [None, FLAGS_max_seq_length], name='input_ids')
+ unique_id = tf.placeholder(tf.int32, [None])
+ input_mask = tf.placeholder(tf.int32, [None, FLAGS_max_seq_length], name='input_mask')
+ segment_ids = tf.placeholder(tf.int32, [None, FLAGS_max_seq_length], name='segment_ids')
+ input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
+ 'input_ids': input_ids,
+ 'input_mask': input_mask,
+ 'segment_ids': segment_ids,
+ 'unique_ids': unique_id,
+ })()
+ return input_fn
+
+
+def main(FLAGS_output_dir, FLAGS_init_checkpoint_squad, FLAGS_export_dir, FLAGS_predict_file=None, FLAGS_train_file=None, FLAGS_do_predict=False,
+ FLAGS_do_train=False, FLAGS_train_batch_size=16, FLAGS_predict_batch_size=8, FLAGS_learning_rate=5e-5, FLAGS_num_train_epochs=3.0,
+ FLAGS_max_answer_length=100, FLAGS_max_query_length=64, FLAGS_version_2_with_negative=False):
+ tf.logging.set_verbosity(tf.logging.INFO)
+
+ bert_config = modeling.BertConfig.from_json_file(FLAGS_bert_config_file)
+
+ validate_flags_or_throw(bert_config)
+
+ tf.gfile.MakeDirs(FLAGS_output_dir)
+
+ tokenizer = tokenization.FullTokenizer(
+ vocab_file=FLAGS_vocab_file, do_lower_case=FLAGS_do_lower_case)
+
+ tpu_cluster_resolver = None
+ if FLAGS_use_tpu and FLAGS_tpu_name:
+ tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
+ FLAGS_tpu_name, zone=FLAGS_tpu_zone, project=FLAGS_gcp_project)
+ is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
+ run_config = tf.contrib.tpu.RunConfig(
+ cluster=tpu_cluster_resolver,
+ master=FLAGS_master,
+ model_dir=FLAGS_output_dir,
+ save_checkpoints_steps=FLAGS_save_checkpoints_steps,
+ tpu_config=tf.contrib.tpu.TPUConfig(
+ iterations_per_loop=FLAGS_iterations_per_loop,
+ num_shards=FLAGS_num_tpu_cores,
+ per_host_input_for_training=is_per_host))
+
+ train_examples = None
+ num_train_steps = None
+ num_warmup_steps = None
+
+ if FLAGS_do_train:
+ train_examples = read_squad_examples(
+ input_file=FLAGS_train_file, is_training=True, questions=questions, FLAGS_version_2_with_negative=FLAGS_version_2_with_negative)
+ num_train_steps = int(
+ len(train_examples) / FLAGS_train_batch_size * FLAGS_num_train_epochs)
+ num_warmup_steps = int(num_train_steps * FLAGS_warmup_proportion)
+
+ # Pre-shuffle the input to avoid having to make a very large shuffle
+ # buffer in in the `input_fn`.
+ rng = random.Random(12345)
+ rng.shuffle(train_examples)
+
+ model_fn = model_fn_builder(
+ bert_config=bert_config,
+ init_checkpoint=FLAGS_init_checkpoint_squad,
+ learning_rate=FLAGS_learning_rate,
+ num_train_steps=num_train_steps,
+ num_warmup_steps=num_warmup_steps,
+ use_tpu=FLAGS_use_tpu,
+ use_one_hot_embeddings=FLAGS_use_tpu)
+
+ # If TPU is not available, this will fall back to normal Estimator on CPU
+ # or GPU.
+ estimator = tf.contrib.tpu.TPUEstimator(
+ use_tpu=FLAGS_use_tpu,
+ model_fn=model_fn,
+ config=run_config,
+ train_batch_size=FLAGS_train_batch_size,
+ predict_batch_size=FLAGS_predict_batch_size)
+
+ if FLAGS_do_train:
+ # We write to a temporary file to avoid storing very large constant tensors
+ # in memory.
+ train_writer = FeatureWriter(
+ filename=os.path.join(FLAGS_output_dir, "train.tf_record"),
+ is_training=True)
+ convert_examples_to_features(
+ examples=train_examples,
+ tokenizer=tokenizer,
+ max_seq_length=FLAGS_max_seq_length,
+ doc_stride=FLAGS_doc_stride,
+ max_query_length=FLAGS_max_query_length,
+ is_training=True,
+ output_fn=train_writer.process_feature)
+ train_writer.close()
+
+ tf.logging.info("***** Running training *****")
+ tf.logging.info(" Num orig examples = %d", len(train_examples))
+ tf.logging.info(" Num split examples = %d", train_writer.num_features)
+ tf.logging.info(" Batch size = %d", FLAGS_train_batch_size)
+ tf.logging.info(" Num steps = %d", num_train_steps)
+ del train_examples
+
+ train_input_fn = input_fn_builder(
+ input_file=train_writer.filename,
+ seq_length=FLAGS_max_seq_length,
+ is_training=True,
+ drop_remainder=True)
+ estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
+ estimator._export_to_tpu = False
+ estimator.export_savedmodel(FLAGS_export_dir, serving_input_fn)
+ return 'success'
+
+
+class AI2Flask:
+
+ def __init__(self, port=5000, workers=4):
+ self.app = app
+ self.port = port
+ p = ThreadPoolExecutor(max_workers=workers)
+ threads_mapping = {}
+
+ def check_threads():
+ flag = False
+ pop_keys = set()
+ if len(threads_mapping) >= workers:
+ for k, v in threads_mapping.items():
+ if v.running():
+ flag = True
+ else:
+ pop_keys.add(k)
+
+ for k in pop_keys:
+ threads_mapping.pop(k)
+
+ return flag
+
+ @app.route('/api/offline/train', methods=['POST'])
+ def text_analyse():
+ if not request.json or 'task_id' not in request.json:
+ abort(400)
+ if check_threads():
+ return jsonify({"Des": "Task list is full. Can not submit new task! ", "Result": "Failed to submit the training task ", "Status": "ERROR"})
+
+ else:
+ try:
+ FLAGS_train_batch_size = request.json['FLAGS_train_batch_size']
+ except:
+ FLAGS_train_batch_size = 16
+ try:
+ FLAGS_learning_rate = request.json['FLAGS_learning_rate']
+ except:
+ FLAGS_learning_rate = 5e-5
+ try:
+ FLAGS_num_train_epochs = request.json['FLAGS_num_train_epochs']
+ except:
+ FLAGS_num_train_epochs = 3.0
+ try:
+ FLAGS_max_answer_length = request.json['FLAGS_max_answer_length']
+ except:
+ FLAGS_max_answer_length = 100
+ try:
+ FLAGS_max_query_length = request.json['FLAGS_max_query_length']
+ except:
+ FLAGS_max_query_length = 64
+ try:
+ FLAGS_version_2_with_negative = request.json['FLAGS_version_2_with_negative']
+ except:
+ FLAGS_version_2_with_negative = True
+
+ try:
+ FLAGS_predict_file = None
+ FLAGS_predict_batch_size = 8
+ FLAGS_do_predict = False
+ FLAGS_do_train = True
+ FLAGS_output_dir = request.json['FLAGS_output_dir']
+ FLAGS_train_file = request.json['FLAGS_train_file']
+ FLAGS_export_dir = request.json['FLAGS_export_dir']
+ task_id = request.json['task_id']
+
+ task = p.submit(main, FLAGS_output_dir, FLAGS_init_checkpoint_squad, FLAGS_export_dir, FLAGS_predict_file, FLAGS_train_file, FLAGS_do_predict,
+ FLAGS_do_train, FLAGS_train_batch_size, FLAGS_predict_batch_size, FLAGS_learning_rate, FLAGS_num_train_epochs,
+ FLAGS_max_answer_length, FLAGS_max_query_length, FLAGS_version_2_with_negative)
+ threads_mapping[task_id] = task
+
+ return jsonify({"message": "Task submitted successfully", "status": "0"})
+
+ except KeyError as e:
+ return jsonify({"Des": 'KeyError: {}'.format(str(e)), "Result": 'None', "Status": "Error"})
+ except Exception as e:
+ return jsonify({"Des": str(e), "Result": 'None', "Status": "Error"})
+
+ @app.route('/api/offline/status', methods=['POST'])
+ def todo_status():
+ task_id = request.json['task_id']
+ task = threads_mapping.get(task_id, None)
+ try:
+ if task is None:
+ return jsonify({'Des': 'The task was not found', 'Status': 'ERROR'})
+ else:
+ if task.done():
+ print(task.result)
+ if task.result() == 'success':
+ return jsonify({'Des': 'DONE', 'Status': 'OK'})
+ else:
+ return jsonify({'Des': 'Program execution error. Please check the execution log ', 'Status': 'ERROR'})
+
+ else:
+ return jsonify({'Des': 'RUNNING', 'Status': 'OK'})
+ except Exception as e:
+ return jsonify({'Des': str(e), 'Status': 'ERROR'})
+
+ def start(self):
+ self.app.run(host="0.0.0.0", port=self.port, threaded=True)
+
+
+if __name__ == '__main__':
+ port = sys.argv[1]
+ AI2Flask(port=port).start()