diff options
Diffstat (limited to 'nlp/scripts/api_squad_offline.py')
-rw-r--r-- | nlp/scripts/api_squad_offline.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/nlp/scripts/api_squad_offline.py b/nlp/scripts/api_squad_offline.py index 8860dfe..a54ab7f 100644 --- a/nlp/scripts/api_squad_offline.py +++ b/nlp/scripts/api_squad_offline.py @@ -38,7 +38,7 @@ from api_squad import validate_flags_or_throw from api_squad import read_squad_examples from global_setting import CUDA_VISIBLE_DEVICES -from global_setting import FLAGS_bert_config_file, FLAGS_vocab_file, FLAGS_init_checkpoint_squad, questions +from global_setting import FLAGS_bert_config_file, FLAGS_vocab_file, FLAGS_init_checkpoint_squad os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = str(CUDA_VISIBLE_DEVICES) @@ -62,7 +62,7 @@ def serving_input_fn(): def main(FLAGS_output_dir, FLAGS_init_checkpoint_squad, FLAGS_export_dir, FLAGS_predict_file=None, FLAGS_train_file=None, FLAGS_do_predict=False, FLAGS_do_train=False, FLAGS_train_batch_size=16, FLAGS_predict_batch_size=8, FLAGS_learning_rate=5e-5, FLAGS_num_train_epochs=3.0, - FLAGS_max_answer_length=100, FLAGS_max_query_length=64, FLAGS_version_2_with_negative=False): + FLAGS_max_answer_length=100, FLAGS_max_query_length=64, FLAGS_version_2_with_negative=False,questions=[]]): tf.logging.set_verbosity(tf.logging.INFO) bert_config = modeling.BertConfig.from_json_file(FLAGS_bert_config_file) @@ -222,10 +222,11 @@ class AI2Flask: FLAGS_train_file = request.json['FLAGS_train_file'] FLAGS_export_dir = request.json['FLAGS_export_dir'] task_id = request.json['task_id'] + questions = request.json['questions'] task = p.submit(main, FLAGS_output_dir, FLAGS_init_checkpoint_squad, FLAGS_export_dir, FLAGS_predict_file, FLAGS_train_file, FLAGS_do_predict, FLAGS_do_train, FLAGS_train_batch_size, FLAGS_predict_batch_size, FLAGS_learning_rate, FLAGS_num_train_epochs, - FLAGS_max_answer_length, FLAGS_max_query_length, FLAGS_version_2_with_negative) + FLAGS_max_answer_length, FLAGS_max_query_length, FLAGS_version_2_with_negative,questions) threads_mapping[task_id] = task return jsonify({"message": "Task submitted successfully", "status": "0"}) |