
terminal_predict.py
import tensorflow as tf
import numpy as np
import codecs
import pickle
import os
from datetime import datetime
from bert_base.train.model_utlis import create_model, InputFeatures
from bert_base.bert import tokenization, modeling

flags = tf.flags
FLAGS = flags.FLAGS

# input/output paths
flags.DEFINE_string('data_dir', 'data', 'dataset directory')
flags.DEFINE_string('output_dir', 'output', 'output directory')
# BERT-related parameters
flags.DEFINE_string('bert_config_file', 'chinese_L-12_H-768_A-12/bert_config.json', 'BERT config file')
flags.DEFINE_string('vocab_file', 'chinese_L-12_H-768_A-12/vocab.txt', 'vocab_file')
flags.DEFINE_string('init_checkpoint', 'chinese_L-12_H-768_A-12/bert_model.ckpt', 'init_checkpoint')
# training and evaluation parameters
flags.DEFINE_bool('do_train', False, 'whether to run training')
flags.DEFINE_bool('do_dev', False, 'whether to run validation')
flags.DEFINE_bool('do_test', True, 'whether to run testing')
flags.DEFINE_bool('do_lower_case', True, 'whether to lower-case the input')
# model parameters
flags.DEFINE_integer('lstm_size', 128, 'lstm_size')
flags.DEFINE_integer('num_layers', 1, 'num_layers')
flags.DEFINE_integer('max_seq_length', 128, 'max_seq_length')
flags.DEFINE_integer('train_batch_size', 64, 'train_batch_size')
flags.DEFINE_integer('dev_batch_size', 64, 'dev_batch_size')
flags.DEFINE_integer('test_batch_size', 32, 'test_batch_size')
flags.DEFINE_integer('save_checkpoints_steps', 500, 'save_checkpoints_steps')
flags.DEFINE_integer('iterations_per_loop', 500, 'iterations_per_loop')
flags.DEFINE_integer('save_summary_steps', 500, 'save_summary_steps')
flags.DEFINE_string('cell', 'lstm', 'cell')
flags.DEFINE_float('learning_rate', 5e-5, 'learning_rate')
flags.DEFINE_float('dropout_rate', 0.5, 'dropout_rate')
flags.DEFINE_float('clip', 0.5, 'clip')
flags.DEFINE_float('num_train_epochs', 10.0, 'num_train_epochs')
flags.DEFINE_float("warmup_proportion", 0.1, 'warmup_proportion')

model_dir = r'output'
bert_dir = 'chinese_L-12_H-768_A-12'

is_training = False
use_one_hot_embeddings = False
batch_size = 1

gpu_config = tf.ConfigProto()
gpu_config.gpu_options.allow_growth = True
sess = tf.Session(config=gpu_config)
model = None

global graph
input_ids_p, input_mask_p, label_ids_p, segment_ids_p = None, None, None, None

print('checkpoint path:{}'.format(os.path.join(model_dir, "checkpoint")))
if not os.path.exists(os.path.join(model_dir, "checkpoint")):
    raise Exception("failed to get checkpoint. going to return")

# load the label->id dictionary
with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf:
    label2id = pickle.load(rf)
    id2label = {value: key for key, value in label2id.items()}

with codecs.open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf:
    label_list = pickle.load(rf)
num_labels = len(label_list) + 1

graph = tf.get_default_graph()
with graph.as_default():
    print("going to restore checkpoint")
    # sess.run(tf.global_variables_initializer())
    input_ids_p = tf.placeholder(tf.int32, [batch_size, FLAGS.max_seq_length], name="input_ids")
    input_mask_p = tf.placeholder(tf.int32, [batch_size, FLAGS.max_seq_length], name="input_mask")

    bert_config = modeling.BertConfig.from_json_file(os.path.join(bert_dir, 'bert_config.json'))
    (total_loss, logits, trans, pred_ids) = create_model(
        bert_config=bert_config, is_training=False, input_ids=input_ids_p, input_mask=input_mask_p,
        segment_ids=None, labels=None, num_labels=num_labels, use_one_hot_embeddings=False, dropout_rate=1.0)

    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint(model_dir))

tokenizer = tokenization.FullTokenizer(
    vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=FLAGS.do_lower_case)
def predict_online():
    """
    Do online prediction: each call makes a prediction for one instance.
    You can change this to a batch if you want.
    :param line: a list. element is: [dummy_label, text_a, text_b]
    :return:
    """
    def convert(line):
        feature = convert_single_example(0, line, label_list, FLAGS.max_seq_length, tokenizer, 'p')
        input_ids = np.reshape([feature.input_ids], (batch_size, FLAGS.max_seq_length))
        input_mask = np.reshape([feature.input_mask], (batch_size, FLAGS.max_seq_length))
        segment_ids = np.reshape([feature.segment_ids], (batch_size, FLAGS.max_seq_length))
        label_ids = np.reshape([feature.label_ids], (batch_size, FLAGS.max_seq_length))
        return input_ids, input_mask, segment_ids, label_ids

    global graph
    with graph.as_default():
        print(id2label)
        while True:
            print('input the test sentence:')
            sentence = str(input())
            start = datetime.now()
            if len(sentence) < 2:
                print(sentence)
                continue
            sentence = tokenizer.tokenize(sentence)
            # print('your input is:{}'.format(sentence))
            input_ids, input_mask, segment_ids, label_ids = convert(sentence)

            feed_dict = {input_ids_p: input_ids,
                         input_mask_p: input_mask}
            # run the session on the current feed_dict
            pred_ids_result = sess.run([pred_ids], feed_dict)
            pred_label_result = convert_id_to_label(pred_ids_result, id2label)
            print(pred_label_result)
            # todo: combination strategy
            result = strage_combined_link_org_loc(sentence, pred_label_result[0])
            print('time used: {} sec'.format((datetime.now() - start).total_seconds()))
def convert_id_to_label(pred_ids_result, idx2label):
    """
    Convert id-form predictions back into the real label sequence.
    :param pred_ids_result:
    :param idx2label:
    :return:
    """
    result = []
    for row in range(batch_size):
        curr_seq = []
        for ids in pred_ids_result[row][0]:
            if ids == 0:
                break
            curr_label = idx2label[ids]
            if curr_label in ['[CLS]', '[SEP]']:
                continue
            curr_seq.append(curr_label)
        result.append(curr_seq)
    return result
def strage_combined_link_org_loc(tokens, tags):
    """
    Combination strategy: group the predicted tags into PER/LOC/ORG entities and print them.
    :param tokens: input tokens
    :param tags: predicted tags
    :return:
    """
    def print_output(data, type):
        line = []
        line.append(type)
        for i in data:
            line.append(i.word)
        print(', '.join(line))

    params = None
    eval = Result(params)
    if len(tokens) > len(tags):
        tokens = tokens[:len(tags)]
    person, loc, org = eval.get_result(tokens, tags)
    print_output(loc, 'LOC')
    print_output(person, 'PER')
    print_output(org, 'ORG')
def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer, mode):
    """
    Analyse one example: convert its characters to ids and its labels to ids,
    then wrap everything in an InputFeatures object.
    :param ex_index: index
    :param example: a single example
    :param label_list: list of labels
    :param max_seq_length:
    :param tokenizer:
    :param mode:
    :return:
    """
    label_map = {}
    # start indexing the labels from 1
    for (i, label) in enumerate(label_list, 1):
        label_map[label] = i
    # save the label->index map
    if not os.path.exists(os.path.join(model_dir, 'label2id.pkl')):
        with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'wb') as w:
            pickle.dump(label_map, w)

    tokens = example
    # tokens = tokenizer.tokenize(example.text)
    # truncate the sequence
    if len(tokens) >= max_seq_length - 1:
        tokens = tokens[0:(max_seq_length - 2)]  # -2 because a start and an end marker will be added
    ntokens = []
    segment_ids = []
    label_ids = []
    ntokens.append("[CLS]")  # mark the start of the sentence with [CLS]
    segment_ids.append(0)
    # append("O") or append("[CLS]") not sure!
    label_ids.append(label_map["[CLS]"])  # 'O' or '[CLS]' makes no real difference; 'O' would mean one fewer label, but marking sentence start/end with their own tags also works fine
    for i, token in enumerate(tokens):
        ntokens.append(token)
        segment_ids.append(0)
        label_ids.append(0)
    ntokens.append("[SEP]")  # mark the end of the sentence with [SEP]
    segment_ids.append(0)
    # append("O") or append("[SEP]") not sure!
    label_ids.append(label_map["[SEP]"])
    input_ids = tokenizer.convert_tokens_to_ids(ntokens)  # convert the tokens (ntokens) to their ids
    input_mask = [1] * len(input_ids)
    # pad up to max_seq_length
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        # padded positions are not of concern here
        label_ids.append(0)
        ntokens.append("**NULL**")
        # label_mask.append(0)
    # print(len(input_ids))
    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(label_ids) == max_seq_length
    # assert len(label_mask) == max_seq_length

    # wrap everything in an InputFeatures object
    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_ids=label_ids,
        # label_mask = label_mask
    )
    return feature
class Pair(object):
    def __init__(self, word, start, end, type, merge=False):
        self.__word = word
        self.__start = start
        self.__end = end
        self.__merge = merge
        self.__types = type

    @property
    def start(self):
        return self.__start

    @property
    def end(self):
        return self.__end

    @property
    def merge(self):
        return self.__merge

    @property
    def word(self):
        return self.__word

    @property
    def types(self):
        return self.__types

    @word.setter
    def word(self, word):
        self.__word = word

    @start.setter
    def start(self, start):
        self.__start = start

    @end.setter
    def end(self, end):
        self.__end = end

    @merge.setter
    def merge(self, merge):
        self.__merge = merge

    @types.setter
    def types(self, type):
        self.__types = type

    def __str__(self) -> str:
        line = []
        line.append('entity:{}'.format(self.__word))
        line.append('start:{}'.format(self.__start))
        line.append('end:{}'.format(self.__end))
        line.append('merge:{}'.format(self.__merge))
        line.append('types:{}'.format(self.__types))
        return '\t'.join(line)
class Result(object):
    def __init__(self, config):
        self.config = config
        self.person = []
        self.loc = []
        self.org = []
        self.others = []

    def get_result(self, tokens, tags, config=None):
        # first collect the tagged entities
        self.result_to_json(tokens, tags)
        return self.person, self.loc, self.org

    def result_to_json(self, string, tags):
        """
        Combine the model's tag sequence with the input sequence to produce entity results.
        :param string: input sequence
        :param tags: predicted tags
        :return:
        """
        item = {"entities": []}
        entity_name = ""
        entity_start = 0
        idx = 0
        last_tag = ''
        for char, tag in zip(string, tags):
            if tag[0] == "S":
                self.append(char, idx, idx + 1, tag[2:])
                item["entities"].append({"word": char, "start": idx, "end": idx + 1, "type": tag[2:]})
            elif tag[0] == "B":
                if entity_name != '':
                    self.append(entity_name, entity_start, idx, last_tag[2:])
                    item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
                    entity_name = ""
                entity_name += char
                entity_start = idx
            elif tag[0] == "I":
                entity_name += char
            elif tag[0] == "O":
                if entity_name != '':
                    self.append(entity_name, entity_start, idx, last_tag[2:])
                    item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
                    entity_name = ""
            else:
                entity_name = ""
                entity_start = idx
            idx += 1
            last_tag = tag
        if entity_name != '':
            self.append(entity_name, entity_start, idx, last_tag[2:])
            item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
        return item

    def append(self, word, start, end, tag):
        if tag == 'LOC':
            self.loc.append(Pair(word, start, end, 'LOC'))
        elif tag == 'PER':
            self.person.append(Pair(word, start, end, 'PER'))
        elif tag == 'ORG':
            self.org.append(Pair(word, start, end, 'ORG'))
        else:
            self.others.append(Pair(word, start, end, tag))


if __name__ == "__main__":
    predict_online()

By collecting COVID-19 related information, classifying and organizing it, and extracting the connections between events, a rich COVID-19 information knowledge graph can be built. Such a knowledge graph makes full use of the information's value and gives people an intuitive point of reference. This project uses the NEO4J graph database to build and apply a knowledge graph of COVID-19 case activity information, in order to trace transmission paths and support epidemic prevention and control.
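
As one possible downstream step, the PER/LOC/ORG entities printed by the NER script above could be written into Neo4j as nodes and relationships. The sketch below is a minimal illustration, assuming the py2neo client and a locally running Neo4j instance; the Case/Person/Location/Organization labels, the CONTACTED/VISITED relationship types, the credentials, and the add_case_activity helper are hypothetical placeholders, not the project's actual schema.

# A minimal sketch of loading NER output into Neo4j, assuming py2neo and a
# local Neo4j instance. Labels, relationship types and credentials below are
# illustrative placeholders, not the project's actual schema.
from py2neo import Graph

graph_db = Graph("bolt://localhost:7687", auth=("neo4j", "password"))

def add_case_activity(case_id, persons, locs, orgs):
    """Link one confirmed case to the PER/LOC/ORG entities recognised for it."""
    graph_db.run("MERGE (:Case {case_id: $cid})", cid=case_id)
    for name in persons:
        graph_db.run(
            "MATCH (c:Case {case_id: $cid}) "
            "MERGE (p:Person {name: $name}) "
            "MERGE (c)-[:CONTACTED]->(p)",
            cid=case_id, name=name)
    for name in locs:
        graph_db.run(
            "MATCH (c:Case {case_id: $cid}) "
            "MERGE (l:Location {name: $name}) "
            "MERGE (c)-[:VISITED]->(l)",
            cid=case_id, name=name)
    for name in orgs:
        graph_db.run(
            "MATCH (c:Case {case_id: $cid}) "
            "MERGE (o:Organization {name: $name}) "
            "MERGE (c)-[:VISITED]->(o)",
            cid=case_id, name=name)

# e.g. feed in the entity strings produced by strage_combined_link_org_loc()
add_case_activity("case_001", persons=["张三"], locs=["武汉"], orgs=["华南海鲜市场"])

Using MERGE rather than CREATE keeps the graph free of duplicate nodes when the same person, location, or organization appears in several case records.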