import tensorflow as tf
import numpy as np
import codecs
import pickle
import os
from datetime import datetime

from bert_base.train.model_utlis import create_model, InputFeatures
from bert_base.bert import tokenization, modeling

flags = tf.flags
FLAGS = flags.FLAGS

# input/output paths
flags.DEFINE_string('data_dir', 'data', 'path to the dataset')
flags.DEFINE_string('output_dir', 'output', 'output directory')

# BERT-related parameters
flags.DEFINE_string('bert_config_file', 'chinese_L-12_H-768_A-12/bert_config.json', 'BERT config file')
flags.DEFINE_string('vocab_file', 'chinese_L-12_H-768_A-12/vocab.txt', 'vocab_file')
flags.DEFINE_string('init_checkpoint', 'chinese_L-12_H-768_A-12/bert_model.ckpt', 'init_checkpoint')

# training / validation / test switches
flags.DEFINE_bool('do_train', False, 'whether to run training')
flags.DEFINE_bool('do_dev', False, 'whether to run validation')
flags.DEFINE_bool('do_test', True, 'whether to run testing')
flags.DEFINE_bool('do_lower_case', True, 'whether to lower-case the input')

# model parameters
flags.DEFINE_integer('lstm_size', 128, 'lstm_size')
flags.DEFINE_integer('num_layers', 1, 'num_layers')
flags.DEFINE_integer('max_seq_length', 128, 'max_seq_length')
flags.DEFINE_integer('train_batch_size', 64, 'train_batch_size')
flags.DEFINE_integer('dev_batch_size', 64, 'dev_batch_size')
flags.DEFINE_integer('test_batch_size', 32, 'test_batch_size')
flags.DEFINE_integer('save_checkpoints_steps', 500, 'save_checkpoints_steps')
flags.DEFINE_integer('iterations_per_loop', 500, 'iterations_per_loop')
flags.DEFINE_integer('save_summary_steps', 500, 'save_summary_steps')
flags.DEFINE_string('cell', 'lstm', 'cell')
flags.DEFINE_float('learning_rate', 5e-5, 'learning_rate')
flags.DEFINE_float('dropout_rate', 0.5, 'dropout_rate')
flags.DEFINE_float('clip', 0.5, 'clip')
flags.DEFINE_float('num_train_epochs', 10.0, 'num_train_epochs')
flags.DEFINE_float("warmup_proportion", 0.1, 'warmup_proportion')

model_dir = r'output'
bert_dir = 'chinese_L-12_H-768_A-12'

is_training = False
use_one_hot_embeddings = False
batch_size = 1

gpu_config = tf.ConfigProto()
gpu_config.gpu_options.allow_growth = True
sess = tf.Session(config=gpu_config)
model = None

global graph
input_ids_p, input_mask_p, label_ids_p, segment_ids_p = None, None, None, None

print('checkpoint path:{}'.format(os.path.join(model_dir, "checkpoint")))
if not os.path.exists(os.path.join(model_dir, "checkpoint")):
    raise Exception("failed to get checkpoint. going to return ")

# load the label->id dictionary
with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf:
    label2id = pickle.load(rf)
    id2label = {value: key for key, value in label2id.items()}

with codecs.open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf:
    label_list = pickle.load(rf)
num_labels = len(label_list) + 1

graph = tf.get_default_graph()
with graph.as_default():
    print("going to restore checkpoint")
    # sess.run(tf.global_variables_initializer())
    input_ids_p = tf.placeholder(tf.int32, [batch_size, FLAGS.max_seq_length], name="input_ids")
    input_mask_p = tf.placeholder(tf.int32, [batch_size, FLAGS.max_seq_length], name="input_mask")

    bert_config = modeling.BertConfig.from_json_file(os.path.join(bert_dir, 'bert_config.json'))
    (total_loss, logits, trans, pred_ids) = create_model(
        bert_config=bert_config, is_training=False, input_ids=input_ids_p, input_mask=input_mask_p,
        segment_ids=None, labels=None, num_labels=num_labels, use_one_hot_embeddings=False, dropout_rate=1.0)

    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint(model_dir))

tokenizer = tokenization.FullTokenizer(
    vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=FLAGS.do_lower_case)
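
# A minimal sketch (never called by the script) of what FullTokenizer returns for a
# Chinese sentence; the sample text and the expected per-character split are
# illustrative assumptions, shown because predict_online() feeds such a token list
# to convert_single_example().
def _tokenizer_demo():
    tokens = tokenizer.tokenize("我在北京")  # the Chinese BERT tokenizer splits CJK text into single characters
    print(tokens)  # expected: ['我', '在', '北', '京']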

def predict_online():
    """
    Do online prediction. Each call makes a prediction for one instance;
    you can change this to a batch if you want.
    :param line: a list. element is: [dummy_label, text_a, text_b]
    :return:
    """
    def convert(line):
        feature = convert_single_example(0, line, label_list, FLAGS.max_seq_length, tokenizer, 'p')
        input_ids = np.reshape([feature.input_ids], (batch_size, FLAGS.max_seq_length))
        input_mask = np.reshape([feature.input_mask], (batch_size, FLAGS.max_seq_length))
        segment_ids = np.reshape([feature.segment_ids], (batch_size, FLAGS.max_seq_length))
        label_ids = np.reshape([feature.label_ids], (batch_size, FLAGS.max_seq_length))
        return input_ids, input_mask, segment_ids, label_ids

    global graph
    with graph.as_default():
        print(id2label)
        while True:
            print('input the test sentence:')
            sentence = str(input())
            start = datetime.now()
            if len(sentence) < 2:
                print(sentence)
                continue
            sentence = tokenizer.tokenize(sentence)
            # print('your input is:{}'.format(sentence))
            input_ids, input_mask, segment_ids, label_ids = convert(sentence)

            feed_dict = {input_ids_p: input_ids,
                         input_mask_p: input_mask}
            # run the session on the current feed_dict
            pred_ids_result = sess.run([pred_ids], feed_dict)
            pred_label_result = convert_id_to_label(pred_ids_result, id2label)
            print(pred_label_result)
            # TODO: combination strategy
            result = strage_combined_link_org_loc(sentence, pred_label_result[0])
            print('time used: {} sec'.format((datetime.now() - start).total_seconds()))


def convert_id_to_label(pred_ids_result, idx2label):
    """
    Convert the id-based predictions back into the real label sequence.
    :param pred_ids_result:
    :param idx2label:
    :return:
    """
    result = []
    for row in range(batch_size):
        curr_seq = []
        for ids in pred_ids_result[row][0]:
            if ids == 0:
                break
            curr_label = idx2label[ids]
            if curr_label in ['[CLS]', '[SEP]']:
                continue
            curr_seq.append(curr_label)
        result.append(curr_seq)
    return result


def strage_combined_link_org_loc(tokens, tags):
    """
    Combination strategy: group the predicted entities by type and print them.
    :param tokens: input token sequence
    :param tags: predicted label sequence
    :return:
    """
    def print_output(data, type):
        line = []
        line.append(type)
        for i in data:
            line.append(i.word)
        print(', '.join(line))

    params = None
    eval = Result(params)
    if len(tokens) > len(tags):
        tokens = tokens[:len(tags)]
    person, loc, org = eval.get_result(tokens, tags)
    print_output(loc, 'LOC')
    print_output(person, 'PER')
    print_output(org, 'ORG')
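
# A minimal sketch (never called) of convert_id_to_label's contract. The id->label
# mapping below is a made-up assumption; the real mapping comes from label2id.pkl.
def _convert_id_to_label_demo():
    demo_idx2label = {1: '[CLS]', 2: 'B-LOC', 3: 'I-LOC', 4: 'O', 5: '[SEP]'}
    # sess.run([pred_ids], ...) returns a list wrapping one (batch, seq_len) array,
    # hence the extra nesting here; id 0 marks padding and stops the decoding loop.
    demo_pred = [np.array([[1, 2, 3, 4, 5, 0, 0]])]
    print(convert_id_to_label(demo_pred, demo_idx2label))  # expected: [['B-LOC', 'I-LOC', 'O']]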

def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer, mode):
    """
    Analyse one example: convert its characters to ids and its labels to ids,
    then wrap everything in an InputFeatures object.
    :param ex_index: index
    :param example: one example
    :param label_list: list of labels
    :param max_seq_length:
    :param tokenizer:
    :param mode:
    :return: InputFeatures
    """
    label_map = {}
    # index the labels starting from 1
    for (i, label) in enumerate(label_list, 1):
        label_map[label] = i
    # persist the label->index map
    if not os.path.exists(os.path.join(model_dir, 'label2id.pkl')):
        with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'wb') as w:
            pickle.dump(label_map, w)

    tokens = example
    # tokens = tokenizer.tokenize(example.text)
    # truncate the sequence; -2 leaves room for the sentence-start and sentence-end markers
    if len(tokens) >= max_seq_length - 1:
        tokens = tokens[0:(max_seq_length - 2)]
    ntokens = []
    segment_ids = []
    label_ids = []
    ntokens.append("[CLS]")  # mark the start of the sentence with [CLS]
    segment_ids.append(0)
    # append("O") or append("[CLS]")? Either works: "O" would mean one label fewer,
    # but tagging the sentence start and end with their own markers is also fine.
    label_ids.append(label_map["[CLS]"])
    for i, token in enumerate(tokens):
        ntokens.append(token)
        segment_ids.append(0)
        label_ids.append(0)
    ntokens.append("[SEP]")  # mark the end of the sentence with [SEP]
    segment_ids.append(0)
    # append("O") or append("[SEP]")? see the note above
    label_ids.append(label_map["[SEP]"])
    input_ids = tokenizer.convert_tokens_to_ids(ntokens)  # convert the tokens (ntokens) to ids
    input_mask = [1] * len(input_ids)
    # pad up to max_seq_length
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)  # we don't care about this here
        label_ids.append(0)
        ntokens.append("**NULL**")
        # label_mask.append(0)
    # print(len(input_ids))
    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(label_ids) == max_seq_length
    # assert len(label_mask) == max_seq_length

    # wrap everything in a structured object
    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_ids=label_ids,
        # label_mask = label_mask
    )
    return feature


class Pair(object):
    def __init__(self, word, start, end, type, merge=False):
        self.__word = word
        self.__start = start
        self.__end = end
        self.__merge = merge
        self.__types = type

    @property
    def start(self):
        return self.__start

    @property
    def end(self):
        return self.__end

    @property
    def merge(self):
        return self.__merge

    @property
    def word(self):
        return self.__word

    @property
    def types(self):
        return self.__types

    @word.setter
    def word(self, word):
        self.__word = word

    @start.setter
    def start(self, start):
        self.__start = start

    @end.setter
    def end(self, end):
        self.__end = end

    @merge.setter
    def merge(self, merge):
        self.__merge = merge

    @types.setter
    def types(self, type):
        self.__types = type

    def __str__(self) -> str:
        line = []
        line.append('entity:{}'.format(self.__word))
        line.append('start:{}'.format(self.__start))
        line.append('end:{}'.format(self.__end))
        line.append('merge:{}'.format(self.__merge))
        line.append('types:{}'.format(self.__types))
        return '\t'.join(line)


class Result(object):
    def __init__(self, config):
        self.config = config
        self.person = []
        self.loc = []
        self.org = []
        self.others = []

    def get_result(self, tokens, tags, config=None):
        # decode the tagged sequence first
        self.result_to_json(tokens, tags)
        return self.person, self.loc, self.org

    def result_to_json(self, string, tags):
        """
        Combine the predicted tag sequence with the input sequence into entities.
        :param string: input sequence
        :param tags: predicted tags
        :return:
        """
        item = {"entities": []}
        entity_name = ""
        entity_start = 0
        idx = 0
        last_tag = ''
        for char, tag in zip(string, tags):
            if tag[0] == "S":
                self.append(char, idx, idx + 1, tag[2:])
                item["entities"].append({"word": char, "start": idx, "end": idx + 1, "type": tag[2:]})
            elif tag[0] == "B":
                if entity_name != '':
                    self.append(entity_name, entity_start, idx, last_tag[2:])
                    item["entities"].append(
                        {"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
                    entity_name = ""
                entity_name += char
                entity_start = idx
            elif tag[0] == "I":
                entity_name += char
            elif tag[0] == "O":
                if entity_name != '':
                    self.append(entity_name, entity_start, idx, last_tag[2:])
                    item["entities"].append(
                        {"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
                    entity_name = ""
            else:
                entity_name = ""
                entity_start = idx
            idx += 1
            last_tag = tag
        if entity_name != '':
            self.append(entity_name, entity_start, idx, last_tag[2:])
            item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
        return item

    def append(self, word, start, end, tag):
        if tag == 'LOC':
            self.loc.append(Pair(word, start, end, 'LOC'))
        elif tag == 'PER':
            self.person.append(Pair(word, start, end, 'PER'))
        elif tag == 'ORG':
            self.org.append(Pair(word, start, end, 'ORG'))
        else:
            self.others.append(Pair(word, start, end, tag))


if __name__ == "__main__":
    predict_online()
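
# A minimal sketch (never executed by the script) of how Result decodes a BIO tag
# sequence into typed entities; the tokens and tags below are illustrative
# assumptions, not actual model output.
def _result_demo():
    tokens = ['张', '三', '在', '上', '海']
    tags = ['B-PER', 'I-PER', 'O', 'B-LOC', 'I-LOC']
    person, loc, org = Result(None).get_result(tokens, tags)
    print([str(p) for p in person])  # expected: one PER entity '张三' with start 0, end 2
    print([str(p) for p in loc])     # expected: one LOC entity '上海' with start 3, end 5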