diff options
Diffstat (limited to 'nlp/scripts/api_squad_offline.py')
-rw-r--r-- | nlp/scripts/api_squad_offline.py | 264 |
1 files changed, 264 insertions, 0 deletions
diff --git a/nlp/scripts/api_squad_offline.py b/nlp/scripts/api_squad_offline.py new file mode 100644 index 0000000..8a05141 --- /dev/null +++ b/nlp/scripts/api_squad_offline.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python +# coding: utf-8 + +# auther = 'liuzhiyong' +# date = 20201204 + + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from flask import Flask, abort, request, jsonify +from concurrent.futures import ThreadPoolExecutor + +import os +import random +import modeling +import tokenization +import tensorflow as tf +import sys + +from api_squad import FLAGS_max_seq_length +from api_squad import FLAGS_do_lower_case +from api_squad import FLAGS_use_tpu +from api_squad import FLAGS_tpu_name +from api_squad import FLAGS_tpu_zone +from api_squad import FLAGS_gcp_project +from api_squad import FLAGS_master +from api_squad import FLAGS_save_checkpoints_steps +from api_squad import FLAGS_iterations_per_loop +from api_squad import FLAGS_num_tpu_cores +from api_squad import FLAGS_warmup_proportion +from api_squad import FLAGS_doc_stride +from api_squad import model_fn_builder +from api_squad import FeatureWriter +from api_squad import convert_examples_to_features +from api_squad import input_fn_builder + +from global_setting import CUDA_VISIBLE_DEVICES +from global_setting import validate_flags_or_throw +from global_setting import read_squad_examples +from global_setting import FLAGS_bert_config_file, FLAGS_vocab_file, FLAGS_init_checkpoint_squad, questions + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = str(CUDA_VISIBLE_DEVICES) + +app = Flask(__name__) + + +def serving_input_fn(): + input_ids = tf.placeholder(tf.int32, [None, FLAGS_max_seq_length], name='input_ids') + unique_id = tf.placeholder(tf.int32, [None]) + input_mask = tf.placeholder(tf.int32, [None, FLAGS_max_seq_length], name='input_mask') + segment_ids = tf.placeholder(tf.int32, [None, FLAGS_max_seq_length], name='segment_ids') + input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({ + 'input_ids': input_ids, + 'input_mask': input_mask, + 'segment_ids': segment_ids, + 'unique_ids': unique_id, + })() + return input_fn + + +def main(FLAGS_output_dir, FLAGS_init_checkpoint_squad, FLAGS_export_dir, FLAGS_predict_file=None, FLAGS_train_file=None, FLAGS_do_predict=False, + FLAGS_do_train=False, FLAGS_train_batch_size=16, FLAGS_predict_batch_size=8, FLAGS_learning_rate=5e-5, FLAGS_num_train_epochs=3.0, + FLAGS_max_answer_length=100, FLAGS_max_query_length=64, FLAGS_version_2_with_negative=False): + tf.logging.set_verbosity(tf.logging.INFO) + + bert_config = modeling.BertConfig.from_json_file(FLAGS_bert_config_file) + + validate_flags_or_throw(bert_config) + + tf.gfile.MakeDirs(FLAGS_output_dir) + + tokenizer = tokenization.FullTokenizer( + vocab_file=FLAGS_vocab_file, do_lower_case=FLAGS_do_lower_case) + + tpu_cluster_resolver = None + if FLAGS_use_tpu and FLAGS_tpu_name: + tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( + FLAGS_tpu_name, zone=FLAGS_tpu_zone, project=FLAGS_gcp_project) + is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 + run_config = tf.contrib.tpu.RunConfig( + cluster=tpu_cluster_resolver, + master=FLAGS_master, + model_dir=FLAGS_output_dir, + save_checkpoints_steps=FLAGS_save_checkpoints_steps, + tpu_config=tf.contrib.tpu.TPUConfig( + iterations_per_loop=FLAGS_iterations_per_loop, + num_shards=FLAGS_num_tpu_cores, + per_host_input_for_training=is_per_host)) + + train_examples = None + num_train_steps = None + num_warmup_steps = None + + if FLAGS_do_train: + train_examples = read_squad_examples( + input_file=FLAGS_train_file, is_training=True, questions=questions, FLAGS_version_2_with_negative=FLAGS_version_2_with_negative) + num_train_steps = int( + len(train_examples) / FLAGS_train_batch_size * FLAGS_num_train_epochs) + num_warmup_steps = int(num_train_steps * FLAGS_warmup_proportion) + + # Pre-shuffle the input to avoid having to make a very large shuffle + # buffer in in the `input_fn`. + rng = random.Random(12345) + rng.shuffle(train_examples) + + model_fn = model_fn_builder( + bert_config=bert_config, + init_checkpoint=FLAGS_init_checkpoint_squad, + learning_rate=FLAGS_learning_rate, + num_train_steps=num_train_steps, + num_warmup_steps=num_warmup_steps, + use_tpu=FLAGS_use_tpu, + use_one_hot_embeddings=FLAGS_use_tpu) + + # If TPU is not available, this will fall back to normal Estimator on CPU + # or GPU. + estimator = tf.contrib.tpu.TPUEstimator( + use_tpu=FLAGS_use_tpu, + model_fn=model_fn, + config=run_config, + train_batch_size=FLAGS_train_batch_size, + predict_batch_size=FLAGS_predict_batch_size) + + if FLAGS_do_train: + # We write to a temporary file to avoid storing very large constant tensors + # in memory. + train_writer = FeatureWriter( + filename=os.path.join(FLAGS_output_dir, "train.tf_record"), + is_training=True) + convert_examples_to_features( + examples=train_examples, + tokenizer=tokenizer, + max_seq_length=FLAGS_max_seq_length, + doc_stride=FLAGS_doc_stride, + max_query_length=FLAGS_max_query_length, + is_training=True, + output_fn=train_writer.process_feature) + train_writer.close() + + tf.logging.info("***** Running training *****") + tf.logging.info(" Num orig examples = %d", len(train_examples)) + tf.logging.info(" Num split examples = %d", train_writer.num_features) + tf.logging.info(" Batch size = %d", FLAGS_train_batch_size) + tf.logging.info(" Num steps = %d", num_train_steps) + del train_examples + + train_input_fn = input_fn_builder( + input_file=train_writer.filename, + seq_length=FLAGS_max_seq_length, + is_training=True, + drop_remainder=True) + estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) + estimator._export_to_tpu = False + estimator.export_savedmodel(FLAGS_export_dir, serving_input_fn) + return 'success' + + +class AI2Flask: + + def __init__(self, port=5000, workers=4): + self.app = app + self.port = port + p = ThreadPoolExecutor(max_workers=workers) + threads_mapping = {} + + def check_threads(): + flag = False + pop_keys = set() + if len(threads_mapping) >= workers: + for k, v in threads_mapping.items(): + if v.running(): + flag = True + else: + pop_keys.add(k) + + for k in pop_keys: + threads_mapping.pop(k) + + return flag + + @app.route('/api/offline/train', methods=['POST']) + def text_analyse(): + if not request.json or 'task_id' not in request.json: + abort(400) + if check_threads(): + return jsonify({"Des": "Task list is full. Can not submit new task! ", "Result": "Failed to submit the training task ", "Status": "ERROR"}) + + else: + try: + FLAGS_train_batch_size = request.json['FLAGS_train_batch_size'] + except: + FLAGS_train_batch_size = 16 + try: + FLAGS_learning_rate = request.json['FLAGS_learning_rate'] + except: + FLAGS_learning_rate = 5e-5 + try: + FLAGS_num_train_epochs = request.json['FLAGS_num_train_epochs'] + except: + FLAGS_num_train_epochs = 3.0 + try: + FLAGS_max_answer_length = request.json['FLAGS_max_answer_length'] + except: + FLAGS_max_answer_length = 100 + try: + FLAGS_max_query_length = request.json['FLAGS_max_query_length'] + except: + FLAGS_max_query_length = 64 + try: + FLAGS_version_2_with_negative = request.json['FLAGS_version_2_with_negative'] + except: + FLAGS_version_2_with_negative = True + + try: + FLAGS_predict_file = None + FLAGS_predict_batch_size = 8 + FLAGS_do_predict = False + FLAGS_do_train = True + FLAGS_output_dir = request.json['FLAGS_output_dir'] + FLAGS_train_file = request.json['FLAGS_train_file'] + FLAGS_export_dir = request.json['FLAGS_export_dir'] + task_id = request.json['task_id'] + + task = p.submit(main, FLAGS_output_dir, FLAGS_init_checkpoint_squad, FLAGS_export_dir, FLAGS_predict_file, FLAGS_train_file, FLAGS_do_predict, + FLAGS_do_train, FLAGS_train_batch_size, FLAGS_predict_batch_size, FLAGS_learning_rate, FLAGS_num_train_epochs, + FLAGS_max_answer_length, FLAGS_max_query_length, FLAGS_version_2_with_negative) + threads_mapping[task_id] = task + + return jsonify({"message": "Task submitted successfully", "status": "0"}) + + except KeyError as e: + return jsonify({"Des": 'KeyError: {}'.format(str(e)), "Result": 'None', "Status": "Error"}) + except Exception as e: + return jsonify({"Des": str(e), "Result": 'None', "Status": "Error"}) + + @app.route('/api/offline/status', methods=['POST']) + def todo_status(): + task_id = request.json['task_id'] + task = threads_mapping.get(task_id, None) + try: + if task is None: + return jsonify({'Des': 'The task was not found', 'Status': 'ERROR'}) + else: + if task.done(): + print(task.result) + if task.result() == 'success': + return jsonify({'Des': 'DONE', 'Status': 'OK'}) + else: + return jsonify({'Des': 'Program execution error. Please check the execution log ', 'Status': 'ERROR'}) + + else: + return jsonify({'Des': 'RUNNING', 'Status': 'OK'}) + except Exception as e: + return jsonify({'Des': str(e), 'Status': 'ERROR'}) + + def start(self): + self.app.run(host="0.0.0.0", port=self.port, threaded=True) + + +if __name__ == '__main__': + port = sys.argv[1] + AI2Flask(port=port).start() |