You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

LstmCrfNer.cs 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Text;
  7. using Tensorflow;
  8. using Tensorflow.Estimator;
  9. using TensorFlowNET.Examples.Utility;
  10. using static Tensorflow.Python;
  11. using static TensorFlowNET.Examples.DataHelpers;
  12. namespace TensorFlowNET.Examples.Text.NER
  13. {
  14. /// <summary>
  15. /// A NER model using Tensorflow (LSTM + CRF + chars embeddings).
  16. /// State-of-the-art performance (F1 score between 90 and 91).
  17. ///
  18. /// https://github.com/guillaumegenthial/sequence_tagging
  19. /// </summary>
  20. public class LstmCrfNer : IExample
  21. {
  22. public int Priority => 14;
  23. public bool Enabled { get; set; } = true;
  24. public bool ImportGraph { get; set; } = true;
  25. public string Name => "LSTM + CRF NER";
  26. HyperParams hp;
  27. int nwords, nchars, ntags;
  28. CoNLLDataset dev, train;
  29. Tensor word_ids_tensor;
  30. Tensor sequence_lengths_tensor;
  31. Tensor char_ids_tensor;
  32. Tensor word_lengths_tensor;
  33. Tensor labels_tensor;
  34. Tensor dropout_tensor;
  35. Tensor lr_tensor;
  36. public bool Run()
  37. {
  38. PrepareData();
  39. var graph = tf.Graph().as_default();
  40. tf.train.import_meta_graph("graph/lstm_crf_ner.meta");
  41. word_ids_tensor = graph.OperationByName("word_ids");
  42. sequence_lengths_tensor = graph.OperationByName("sequence_lengths");
  43. char_ids_tensor = graph.OperationByName("char_ids");
  44. word_lengths_tensor = graph.OperationByName("word_lengths");
  45. labels_tensor = graph.OperationByName("labels");
  46. dropout_tensor = graph.OperationByName("dropout");
  47. lr_tensor = graph.OperationByName("lr");
  48. var init = tf.global_variables_initializer();
  49. with(tf.Session(), sess =>
  50. {
  51. sess.run(init);
  52. foreach (var epoch in range(hp.epochs))
  53. {
  54. print($"Epoch {epoch + 1} out of {hp.epochs}");
  55. run_epoch(train, dev, epoch);
  56. }
  57. });
  58. return true;
  59. }
  60. private void run_epoch(CoNLLDataset train, CoNLLDataset dev, int epoch)
  61. {
  62. int i = 0;
  63. // iterate over dataset
  64. var batches = minibatches(train, hp.batch_size);
  65. foreach (var(words, labels) in batches)
  66. {
  67. get_feed_dict(words, labels, hp.lr, hp.dropout);
  68. }
  69. }
  70. private IEnumerable<((int[][], int[])[], int[][])> minibatches(CoNLLDataset data, int minibatch_size)
  71. {
  72. var x_batch = new List<(int[][], int[])>();
  73. var y_batch = new List<int[]>();
  74. foreach(var (x, y) in data.GetItems())
  75. {
  76. if (len(y_batch) == minibatch_size)
  77. {
  78. yield return (x_batch.ToArray(), y_batch.ToArray());
  79. x_batch.Clear();
  80. y_batch.Clear();
  81. }
  82. var x3 = (x.Select(x1 => x1.Item1).ToArray(), x.Select(x2 => x2.Item2).ToArray());
  83. x_batch.Add(x3);
  84. y_batch.Add(y);
  85. }
  86. if (len(y_batch) > 0)
  87. yield return (x_batch.ToArray(), y_batch.ToArray());
  88. }
  89. /// <summary>
  90. /// Given some data, pad it and build a feed dictionary
  91. /// </summary>
  92. /// <param name="words">
  93. /// list of sentences. A sentence is a list of ids of a list of
  94. /// words. A word is a list of ids
  95. /// </param>
  96. /// <param name="labels">list of ids</param>
  97. /// <param name="lr">learning rate</param>
  98. /// <param name="dropout">keep prob</param>
  99. private FeedItem[] get_feed_dict((int[][], int[])[] words, int[][] labels, float lr = 0f, float dropout = 0f)
  100. {
  101. int[] sequence_lengths;
  102. int[][] word_lengths;
  103. int[][] word_ids;
  104. int[][][] char_ids;
  105. if (true) // use_chars
  106. {
  107. (char_ids, word_ids) = (words.Select(x => x.Item1).ToArray(), words.Select(x => x.Item2).ToArray());
  108. (word_ids, sequence_lengths) = pad_sequences(word_ids, pad_tok: 0);
  109. (char_ids, word_lengths) = pad_sequences(char_ids, pad_tok: 0);
  110. }
  111. // build feed dictionary
  112. var feeds = new List<FeedItem>();
  113. feeds.Add(new FeedItem(word_ids_tensor, np.array(word_ids)));
  114. feeds.Add(new FeedItem(sequence_lengths_tensor, np.array(sequence_lengths)));
  115. if(true) // use_chars
  116. {
  117. feeds.Add(new FeedItem(char_ids_tensor, np.array(char_ids)));
  118. feeds.Add(new FeedItem(word_lengths_tensor, np.array(word_lengths)));
  119. }
  120. throw new NotImplementedException("get_feed_dict");
  121. }
  122. public void PrepareData()
  123. {
  124. hp = new HyperParams("LstmCrfNer")
  125. {
  126. epochs = 15,
  127. dropout = 0.5f,
  128. batch_size = 20,
  129. lr_method = "adam",
  130. lr = 0.001f,
  131. lr_decay = 0.9f,
  132. clip = false,
  133. epoch_no_imprv = 3,
  134. hidden_size_char = 100,
  135. hidden_size_lstm = 300
  136. };
  137. hp.filepath_dev = hp.filepath_test = hp.filepath_train = Path.Combine(hp.data_root_dir, "test.txt");
  138. // Loads vocabulary, processing functions and embeddings
  139. hp.filepath_words = Path.Combine(hp.data_root_dir, "words.txt");
  140. hp.filepath_tags = Path.Combine(hp.data_root_dir, "tags.txt");
  141. hp.filepath_chars = Path.Combine(hp.data_root_dir, "chars.txt");
  142. // 1. vocabulary
  143. /*vocab_tags = load_vocab(hp.filepath_tags);
  144. nwords = vocab_words.Count;
  145. nchars = vocab_chars.Count;
  146. ntags = vocab_tags.Count;*/
  147. // 2. get processing functions that map str -> id
  148. dev = new CoNLLDataset(hp.filepath_dev, hp);
  149. train = new CoNLLDataset(hp.filepath_train, hp);
  150. }
  151. }
  152. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。