You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

bilstm_crf.py 2.7 kB

6 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. #!/usr/bin/env python3
  2. # -*-coding:utf-8-*-
  3. """
  4. * Copyright (C) 2018 OwnThink.
  5. *
  6. * Name : bilstm_crf.py - 预测
  7. * Author : Yener <yener@ownthink.com>
  8. * Version : 0.01
  9. * Description :
  10. """
  11. import os
  12. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  13. import pickle
  14. import numpy as np
  15. import tensorflow as tf
  16. from tensorflow.contrib.crf import viterbi_decode
  17. class Predict(object):
  18. def __init__(self, model_file):
  19. with open(model_file, 'rb') as f:
  20. model, char_to_id, id_to_tag = pickle.load(f)
  21. self.char_to_id = char_to_id
  22. self.id_to_tag = {int(k): v for k, v in id_to_tag.items()}
  23. self.num_class = len(self.id_to_tag)
  24. graph_def = tf.GraphDef()
  25. graph_def.ParseFromString(model)
  26. with tf.Graph().as_default() as graph:
  27. tf.import_graph_def(graph_def, name="prefix")
  28. self.input_x = graph.get_tensor_by_name("prefix/char_inputs:0")
  29. self.lengths = graph.get_tensor_by_name("prefix/lengths:0")
  30. self.dropout = graph.get_tensor_by_name("prefix/dropout:0")
  31. self.logits = graph.get_tensor_by_name("prefix/project/logits:0")
  32. self.trans = graph.get_tensor_by_name("prefix/crf_loss/transitions:0")
  33. self.sess = tf.Session(graph=graph)
  34. self.sess.as_default()
  35. def decode(self, logits, trans, sequence_lengths, tag_num):
  36. small = -1000.0
  37. viterbi_sequences = []
  38. start = np.asarray([[small] * tag_num + [0]])
  39. for logit, length in zip(logits, sequence_lengths):
  40. score = logit[:length]
  41. pad = small * np.ones([length, 1])
  42. score = np.concatenate([score, pad], axis=1)
  43. score = np.concatenate([start, score], axis=0)
  44. viterbi_seq, viterbi_score = viterbi_decode(score, trans)
  45. viterbi_sequences.append(viterbi_seq[1:])
  46. return viterbi_sequences
  47. def predict(self, sents):
  48. inputs = []
  49. lengths = [len(text) for text in sents]
  50. max_len = max(lengths)
  51. for sent in sents:
  52. sent_ids = [self.char_to_id.get(w) if w in self.char_to_id else self.char_to_id.get("<OOV>") for w in sent]
  53. padding = [0] * (max_len - len(sent_ids))
  54. sent_ids += padding
  55. inputs.append(sent_ids)
  56. inputs = np.array(inputs, dtype=np.int32)
  57. feed_dict = {
  58. self.input_x: inputs,
  59. self.lengths: lengths,
  60. self.dropout: 1.0
  61. }
  62. logits, trans = self.sess.run([self.logits, self.trans], feed_dict=feed_dict)
  63. path = self.decode(logits, trans, lengths, self.num_class)
  64. labels = [[self.id_to_tag.get(l) for l in p] for p in path]
  65. return labels

Jiagu使用大规模语料训练而成，提供中文分词、词性标注、命名实体识别、情感分析、知识图谱关系抽取、关键词抽取、文本摘要、新词发现、文本聚类等常用自然语言处理功能。Jiagu参考了各大工具的优缺点制作而成，现将其回馈给大家。