Diffstat (limited to 'scripts/api_squad_offline.py')
-rw-r--r-- | scripts/api_squad_offline.py | 48 |
1 file changed, 27 insertions, 21 deletions
diff --git a/scripts/api_squad_offline.py b/scripts/api_squad_offline.py
index 1c98a10..8a05141 100644
--- a/scripts/api_squad_offline.py
+++ b/scripts/api_squad_offline.py
@@ -1,4 +1,3 @@
-
 #!/usr/bin/env python
 # coding: utf-8
 
@@ -9,25 +8,36 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import json
-import datetime
-import threading
-import time
 from flask import Flask, abort, request, jsonify
 from concurrent.futures import ThreadPoolExecutor
-import collections
-import math
 import os
 import random
 import modeling
-import optimization
 import tokenization
-import six
 import tensorflow as tf
 import sys
-from api_squad import *
-from global_setting import *
+
+from api_squad import FLAGS_max_seq_length
+from api_squad import FLAGS_do_lower_case
+from api_squad import FLAGS_use_tpu
+from api_squad import FLAGS_tpu_name
+from api_squad import FLAGS_tpu_zone
+from api_squad import FLAGS_gcp_project
+from api_squad import FLAGS_master
+from api_squad import FLAGS_save_checkpoints_steps
+from api_squad import FLAGS_iterations_per_loop
+from api_squad import FLAGS_num_tpu_cores
+from api_squad import FLAGS_warmup_proportion
+from api_squad import FLAGS_doc_stride
+from api_squad import model_fn_builder
+from api_squad import FeatureWriter
+from api_squad import convert_examples_to_features
+from api_squad import input_fn_builder
+
+from global_setting import CUDA_VISIBLE_DEVICES
+from global_setting import validate_flags_or_throw
+from global_setting import read_squad_examples
 from global_setting import FLAGS_bert_config_file, FLAGS_vocab_file, FLAGS_init_checkpoint_squad, questions
 
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@@ -35,9 +45,10 @@ os.environ["CUDA_VISIBLE_DEVICES"] = str(CUDA_VISIBLE_DEVICES)
 
 app = Flask(__name__)
 
+
 def serving_input_fn():
     input_ids = tf.placeholder(tf.int32, [None, FLAGS_max_seq_length], name='input_ids')
-    unique_id = tf.placeholder(tf.int32,[None])
+    unique_id = tf.placeholder(tf.int32, [None])
     input_mask = tf.placeholder(tf.int32, [None, FLAGS_max_seq_length], name='input_mask')
     segment_ids = tf.placeholder(tf.int32, [None, FLAGS_max_seq_length], name='segment_ids')
     input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
@@ -45,13 +56,13 @@ def serving_input_fn():
         'input_mask': input_mask,
         'segment_ids': segment_ids,
         'unique_ids': unique_id,
-        })()
+    })()
     return input_fn
 
 
 
+
 def main(FLAGS_output_dir, FLAGS_init_checkpoint_squad, FLAGS_export_dir, FLAGS_predict_file=None, FLAGS_train_file=None, FLAGS_do_predict=False, FLAGS_do_train=False, FLAGS_train_batch_size=16, FLAGS_predict_batch_size=8, FLAGS_learning_rate=5e-5, FLAGS_num_train_epochs=3.0, FLAGS_max_answer_length=100, FLAGS_max_query_length=64, FLAGS_version_2_with_negative=False):
-
     tf.logging.set_verbosity(tf.logging.INFO)
 
     bert_config = modeling.BertConfig.from_json_file(FLAGS_bert_config_file)
@@ -60,7 +71,6 @@ def main(FLAGS_output_dir, FLAGS_init_checkpoint_squad, FLAGS_export_dir, FLAGS_
 
     tf.gfile.MakeDirs(FLAGS_output_dir)
 
-
     tokenizer = tokenization.FullTokenizer(
         vocab_file=FLAGS_vocab_file, do_lower_case=FLAGS_do_lower_case)
 
@@ -68,7 +78,6 @@ def main(FLAGS_output_dir, FLAGS_init_checkpoint_squad, FLAGS_export_dir, FLAGS_
     if FLAGS_use_tpu and FLAGS_tpu_name:
         tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
             FLAGS_tpu_name, zone=FLAGS_tpu_zone, project=FLAGS_gcp_project)
-
     is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
     run_config = tf.contrib.tpu.RunConfig(
         cluster=tpu_cluster_resolver,
@@ -86,8 +95,7 @@ def main(FLAGS_output_dir, FLAGS_init_checkpoint_squad, FLAGS_export_dir, FLAGS_
 
     if FLAGS_do_train:
         train_examples = read_squad_examples(
-            input_file=FLAGS_train_file, is_training=True,questions = questions,FLAGS_version_2_with_negative = FLAGS_version_2_with_negative)
-
+            input_file=FLAGS_train_file, is_training=True, questions=questions, FLAGS_version_2_with_negative=FLAGS_version_2_with_negative)
         num_train_steps = int(
             len(train_examples) / FLAGS_train_batch_size * FLAGS_num_train_epochs)
         num_warmup_steps = int(num_train_steps * FLAGS_warmup_proportion)
@@ -174,7 +182,7 @@ class AI2Flask:
 
         @app.route('/api/offline/train', methods=['POST'])
        def text_analyse():
-            if not request.json or not 'task_id' in request.json:
+            if not request.json or 'task_id' not in request.json:
                 abort(400)
             if check_threads():
                 return jsonify({"Des": "Task list is full. Can not submit new task! ", "Result": "Failed to submit the training task ", "Status": "ERROR"})
@@ -227,8 +235,6 @@ class AI2Flask:
             except Exception as e:
                 return jsonify({"Des": str(e), "Result": 'None', "Status": "Error"})
 
-
-
         @app.route('/api/offline/status', methods=['POST'])
         def todo_status():
             task_id = request.json['task_id']