You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

model_utils.py 4.5 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import os
  2. import json
  3. import logging
  4. from collections import OrderedDict
  5. from conlleval import return_report
  6. import codecs
  7. import tensorflow as tf
  8. def get_logger(log_file):
  9. """
  10. 定义日志方法
  11. :param log_file:
  12. :return:
  13. """
  14. # 创建一个logging的实例 logger
  15. logger = logging.getLogger(log_file)
  16. # 设置logger的全局日志级别为DEBUG
  17. logger.setLevel(logging.DEBUG)
  18. # 创建一个日志文件的handler,并且设置日志级别为DEBUG
  19. fh = logging.FileHandler(log_file)
  20. fh.setLevel(logging.DEBUG)
  21. # 创建一个控制台的handler,并设置日志级别为DEBUG
  22. ch = logging.StreamHandler()
  23. ch.setLevel(logging.INFO)
  24. # 设置日志格式
  25. formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
  26. # add formatter to ch and fh
  27. ch.setFormatter(formatter)
  28. fh.setFormatter(formatter)
  29. # add ch and fh to logger
  30. logger.addHandler(ch)
  31. logger.addHandler(fh)
  32. return logger
  33. def config_model(FLAGS, word_to_id, tag_to_id):
  34. config = OrderedDict()
  35. config['num_words'] = len(word_to_id)
  36. config['word_dim'] = FLAGS.word_dim
  37. config['num_tags'] = len(tag_to_id)
  38. config['seg_dim'] = FLAGS.seg_dim
  39. config['lstm_dim'] = FLAGS.lstm_dim
  40. config['batch_size'] = FLAGS.batch_size
  41. config['optimizer'] = FLAGS.optimizer
  42. config['emb_file'] = FLAGS.emb_file
  43. config['clip'] = FLAGS.clip
  44. config['dropout_keep'] = 1.0 - FLAGS.dropout
  45. config['optimizer'] = FLAGS.optimizer
  46. config['lr'] = FLAGS.lr
  47. config['tag_schema'] = FLAGS.tag_schema
  48. config['pre_emb'] = FLAGS.pre_emb
  49. config['model_type'] = FLAGS.model_type
  50. config['is_train'] = FLAGS.train
  51. return config
  52. def make_path(params):
  53. """
  54. 创建文件夹
  55. :param params:
  56. :return:
  57. """
  58. if not os.path.isdir(params.result_path):
  59. os.makedirs(params.result_path)
  60. if not os.path.isdir(params.ckpt_path):
  61. os.makedirs(params.ckpt_path)
  62. if not os.path.isdir('log'):
  63. os.makedirs('log')
  64. def save_config(config, config_file):
  65. """
  66. 保存配置文件
  67. :param config:
  68. :param config_path:
  69. :return:
  70. """
  71. with open(config_file, 'w', encoding='utf-8') as f:
  72. json.dump(config, f, ensure_ascii=False, indent=4)
  73. def load_config(config_file):
  74. """
  75. 加载配置文件
  76. :param config_file:
  77. :return:
  78. """
  79. with open(config_file, encoding='utf-8') as f:
  80. return json.load(f)
  81. def print_config(config, logger):
  82. """
  83. 打印模型参数
  84. :param config:
  85. :param logger:
  86. :return:
  87. """
  88. for k, v in config.items():
  89. logger.info("{}:\t{}".format(k.ljust(15), v))
  90. def create(sess, Model, ckpt_path, load_word2vec, config, id_to_word, logger):
  91. """
  92. :param sess:
  93. :param Model:
  94. :param ckpt_path:
  95. :param load_word2vec:
  96. :param config:
  97. :param id_to_word:
  98. :param logger:
  99. :return:
  100. """
  101. model = Model(config)
  102. ckpt = tf.train.get_checkpoint_state(ckpt_path)
  103. if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
  104. logger("读取模型参数,从%s" % ckpt.model_checkpoint_path)
  105. model.saver.restore(sess, ckpt.model_checkpoint_path)
  106. else:
  107. logger.info("重新训练模型")
  108. sess.run(tf.global_variables_initializer())
  109. if config['pre_emb']:
  110. emb_weights = sess.run(model.word_lookup.read_value())
  111. emb_weights = load_word2vec(config['emb_file'], id_to_word, config['word_dim'], emb_weights)
  112. sess.run(model.word_lookup.assign(emb_weights))
  113. logger.info("加载词向量成功")
  114. return model
  115. def test_ner(results, path):
  116. """
  117. :param results:
  118. :param path:
  119. :return:
  120. """
  121. output_file = os.path.join(path, 'ner_predict.utf8')
  122. with codecs.open(output_file, "w", encoding="utf-8") as f_write:
  123. to_write = []
  124. for line in results:
  125. for iner_line in line:
  126. to_write.append(iner_line + "\n")
  127. to_write.append("\n")
  128. f_write.writelines(to_write)
  129. eval_lines = return_report(output_file)
  130. return eval_lines
  131. def save_model(sess, model, path, logger):
  132. """
  133. :param sess:
  134. :param model:
  135. :param path:
  136. :param logger:
  137. :return:
  138. """
  139. checkpoint_path = os.path.join(path, "ner.ckpt")
  140. model.saver.save(sess, checkpoint_path)
  141. logger.info('模型已经保存')

通过对新冠疫情相关信息的收集,进行分类、归纳,取得事件之间的联系,可以构成一个丰富的新冠信息知识图谱。新冠信息知识图谱的构建能够充分挖掘信息价值,为人们提供直观的参考依据。本项目基于NEO4J图数据库,来进行COVID-19病例活动行径信息的知识图谱构建与应用,达到追溯传播途径、疫情防控的目的。