diff options
Diffstat (limited to 'scripts/api_squad.py')
-rw-r--r-- | scripts/api_squad.py | 45 |
1 files changed, 14 insertions, 31 deletions
diff --git a/scripts/api_squad.py b/scripts/api_squad.py index 239bbd6..f29a74b 100644 --- a/scripts/api_squad.py +++ b/scripts/api_squad.py @@ -1,6 +1,6 @@ # coding=utf-8 # squad interface -# Required parameters: +# Required parameters # FLAGS_output_dir :the output path of the model training during training process, the output of the trained model, etc.; the output path of the model prediction during predicting process # FLAGS_init_checkpoint_squad : model initialization path, use bert pre-trained model for training; use the output path during training for prediction # FLAGS_predict_file : the file to be predicted, csv file @@ -22,18 +22,13 @@ from __future__ import print_function import collections import json import math -import os -import random import modeling import optimization import tokenization import six import tensorflow as tf import pandas as pd -from global_setting import FLAGS_bert_config_file, FLAGS_vocab_file, FLAGS_init_checkpoint_squad - - - +from global_setting import FLAGS_init_checkpoint_squad FLAGS_max_seq_length = 512 FLAGS_do_lower_case = True @@ -53,11 +48,12 @@ FLAGS_warmup_proportion = 0.1 FLAGS_gcp_project = None FLAGS_null_score_diff_threshold = 0.0 -def make_json(input_file,questions): + +def make_json(input_file, questions): print(input_file) data_train = pd.read_excel(input_file) print(444) - data_train.fillna(0,inplace=True) + data_train.fillna(0, inplace=True) data_train.index = [i for i in range(len(data_train))] question = questions res = {} @@ -67,17 +63,16 @@ def make_json(input_file,questions): data_inside['title'] = 'Not available' data_inside['paragraphs'] = [] paragraphs_inside = {} - paragraphs_inside['context'] = data_train.loc[i,'text'] + paragraphs_inside['context'] = data_train.loc[i, 'text'] paragraphs_inside['qas'] = [] - for ques in question: + for ques in question: qas_inside = {} qas_inside['answers'] = [] - if data_train.loc[i,ques]: + if data_train.loc[i, ques]: answer_inside = {} - answer_inside['text'] = str(data_train.loc[i,ques]) + answer_inside['text'] = str(data_train.loc[i, ques]) answer_inside['answer_start'] = paragraphs_inside['context'].find(answer_inside['text']) qas_inside['is_impossible'] = 0 - else: qas_inside['is_impossible'] = 1 answer_inside = {} @@ -92,8 +87,6 @@ def make_json(input_file,questions): return json.dumps(res) - - class SquadExample(object): """A single training/test example for simple sequence classification. @@ -164,9 +157,9 @@ class InputFeatures(object): self.is_impossible = is_impossible -def read_squad_examples(input_file, is_training,questions,FLAGS_version_2_with_negative): +def read_squad_examples(input_file, is_training, questions, FLAGS_version_2_with_negative): """Read a SQuAD json file into a list of SquadExample.""" - data = make_json(input_file,questions) + data = make_json(input_file, questions) input_data = json.loads(data)["data"] def is_whitespace(c): @@ -212,8 +205,7 @@ def read_squad_examples(input_file, is_training,questions,FLAGS_version_2_with_n answer_offset = answer["answer_start"] answer_length = len(orig_answer_text) start_position = char_to_word_offset[answer_offset] - end_position = char_to_word_offset[answer_offset + answer_length - - 1] + end_position = char_to_word_offset[answer_offset + answer_length - 1] # Only add answers where the text can be exactly recovered from the # document. If this CAN'T happen it's likely due to weird Unicode # stuff so we will just skip the example. @@ -353,8 +345,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, doc_start = doc_span.start doc_end = doc_span.start + doc_span.length - 1 out_of_span = False - if not (tok_start_position >= doc_start and - tok_end_position <= doc_end): + if not (tok_start_position >= doc_start and tok_end_position <= doc_end): out_of_span = True if out_of_span: start_position = 0 @@ -544,7 +535,6 @@ def model_fn_builder(bert_config, init_checkpoint, learning_rate, tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) - unique_ids = features["unique_ids"] input_ids = features["input_ids"] input_mask = features["input_mask"] segment_ids = features["segment_ids"] @@ -686,7 +676,7 @@ RawResult = collections.namedtuple("RawResult", def write_predictions(all_examples, all_features, all_results, n_best_size, max_answer_length, do_lower_case, output_prediction_file, - output_nbest_file, output_null_log_odds_file): + output_nbest_file, output_null_log_odds_file, FLAGS_version_2_with_negative): """Write final predictions to the json file and log-odds of null if needed.""" tf.logging.info("Writing predictions to: %s" % (output_prediction_file)) tf.logging.info("Writing nbest to: %s" % (output_nbest_file)) @@ -705,7 +695,6 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, all_predictions = collections.OrderedDict() all_nbest_json = collections.OrderedDict() - scores_diff_json = collections.OrderedDict() for (example_index, example) in enumerate(all_examples): features = example_index_to_features[example_index] @@ -713,9 +702,6 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, prelim_predictions = [] # keep track of the minimum score of null start+end of position 0 score_null = 1000000 # large and positive - min_null_feature_index = 0 # the paragraph slice with min mull score - null_start_logit = 0 # the start logit at the slice with min null score - null_end_logit = 0 # the end logit at the slice with min null score for (feature_index, feature) in enumerate(features): result = unique_id_to_result[feature.unique_id] start_indexes = _get_best_indexes(result.start_logits, n_best_size) @@ -726,9 +712,6 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, result.end_logits[0] if feature_null_score < score_null: score_null = feature_null_score - min_null_feature_index = feature_index - null_start_logit = result.start_logits[0] - null_end_logit = result.end_logits[0] for start_index in start_indexes: for end_index in end_indexes: # We could hypothetically create invalid predictions, e.g., predict |