You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

LstmCrfNer.cs 8.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  6. using Tensorflow;
  7. using Tensorflow.Estimator;
  8. using TensorFlowNET.Examples.Utility;
  9. using static Tensorflow.Python;
  10. using static TensorFlowNET.Examples.DataHelpers;
  11. namespace TensorFlowNET.Examples.Text.NER
  12. {
  13. /// <summary>
  14. /// A NER model using Tensorflow (LSTM + CRF + chars embeddings).
  15. /// State-of-the-art performance (F1 score between 90 and 91).
  16. ///
  17. /// https://github.com/guillaumegenthial/sequence_tagging
  18. /// </summary>
  19. public class LstmCrfNer : IExample
  20. {
  21. public bool Enabled { get; set; } = true;
  22. public bool IsImportingGraph { get; set; } = true;
  23. public string Name => "LSTM + CRF NER";
  24. HyperParams hp;
  25. int nwords, nchars, ntags;
  26. CoNLLDataset dev, train;
  27. Tensor word_ids_tensor;
  28. Tensor sequence_lengths_tensor;
  29. Tensor char_ids_tensor;
  30. Tensor word_lengths_tensor;
  31. Tensor labels_tensor;
  32. Tensor dropout_tensor;
  33. Tensor lr_tensor;
  34. Operation train_op;
  35. Tensor loss;
  36. Tensor merged;
  37. public bool Run()
  38. {
  39. PrepareData();
  40. var graph = tf.Graph().as_default();
  41. tf.train.import_meta_graph("graph/lstm_crf_ner.meta");
  42. float loss_value = 0f;
  43. //add_summary();
  44. word_ids_tensor = graph.OperationByName("word_ids");
  45. sequence_lengths_tensor = graph.OperationByName("sequence_lengths");
  46. char_ids_tensor = graph.OperationByName("char_ids");
  47. word_lengths_tensor = graph.OperationByName("word_lengths");
  48. labels_tensor = graph.OperationByName("labels");
  49. dropout_tensor = graph.OperationByName("dropout");
  50. lr_tensor = graph.OperationByName("lr");
  51. train_op = graph.OperationByName("train_step/Adam");
  52. loss = graph.OperationByName("Mean");
  53. //merged = graph.OperationByName("Merge/MergeSummary");
  54. var init = tf.global_variables_initializer();
  55. using (var sess = tf.Session())
  56. {
  57. sess.run(init);
  58. foreach (var epoch in range(hp.epochs))
  59. {
  60. Console.Write($"Epoch {epoch + 1} out of {hp.epochs}, ");
  61. loss_value = run_epoch(sess, train, dev, epoch);
  62. print($"train loss: {loss_value}");
  63. }
  64. }
  65. return loss_value < 0.1;
  66. }
  67. private float run_epoch(Session sess, CoNLLDataset train, CoNLLDataset dev, int epoch)
  68. {
  69. NDArray results = null;
  70. // iterate over dataset
  71. var batches = minibatches(train, hp.batch_size);
  72. foreach (var(words, labels) in batches)
  73. {
  74. var (fd, _) = get_feed_dict(words, labels, hp.lr, hp.dropout);
  75. results = sess.run(new ITensorOrOperation[] { train_op, loss }, feed_dict: fd);
  76. }
  77. return results[1];
  78. }
  79. private IEnumerable<((int[][], int[])[], int[][])> minibatches(CoNLLDataset data, int minibatch_size)
  80. {
  81. var x_batch = new List<(int[][], int[])>();
  82. var y_batch = new List<int[]>();
  83. foreach(var (x, y) in data.GetItems())
  84. {
  85. if (len(y_batch) == minibatch_size)
  86. {
  87. yield return (x_batch.ToArray(), y_batch.ToArray());
  88. x_batch.Clear();
  89. y_batch.Clear();
  90. }
  91. var x3 = (x.Select(x1 => x1.Item1).ToArray(), x.Select(x2 => x2.Item2).ToArray());
  92. x_batch.Add(x3);
  93. y_batch.Add(y);
  94. }
  95. if (len(y_batch) > 0)
  96. yield return (x_batch.ToArray(), y_batch.ToArray());
  97. }
  98. /// <summary>
  99. /// Given some data, pad it and build a feed dictionary
  100. /// </summary>
  101. /// <param name="words">
  102. /// list of sentences. A sentence is a list of ids of a list of
  103. /// words. A word is a list of ids
  104. /// </param>
  105. /// <param name="labels">list of ids</param>
  106. /// <param name="lr">learning rate</param>
  107. /// <param name="dropout">keep prob</param>
  108. private (FeedItem[], int[]) get_feed_dict((int[][], int[])[] words, int[][] labels, float lr = 0f, float dropout = 0f)
  109. {
  110. int[] sequence_lengths;
  111. int[][] word_lengths;
  112. int[][] word_ids;
  113. int[][][] char_ids;
  114. if (true) // use_chars
  115. {
  116. (char_ids, word_ids) = (words.Select(x => x.Item1).ToArray(), words.Select(x => x.Item2).ToArray());
  117. (word_ids, sequence_lengths) = pad_sequences(word_ids, pad_tok: 0);
  118. (char_ids, word_lengths) = pad_sequences(char_ids, pad_tok: 0);
  119. }
  120. // build feed dictionary
  121. var feeds = new List<FeedItem>();
  122. feeds.Add(new FeedItem(word_ids_tensor, np.array(word_ids)));
  123. feeds.Add(new FeedItem(sequence_lengths_tensor, np.array(sequence_lengths)));
  124. if(true) // use_chars
  125. {
  126. feeds.Add(new FeedItem(char_ids_tensor, np.array(char_ids)));
  127. feeds.Add(new FeedItem(word_lengths_tensor, np.array(word_lengths)));
  128. }
  129. (labels, _) = pad_sequences(labels, 0);
  130. feeds.Add(new FeedItem(labels_tensor, np.array(labels)));
  131. feeds.Add(new FeedItem(lr_tensor, lr));
  132. feeds.Add(new FeedItem(dropout_tensor, dropout));
  133. return (feeds.ToArray(), sequence_lengths);
  134. }
  135. public void PrepareData()
  136. {
  137. hp = new HyperParams("LstmCrfNer")
  138. {
  139. epochs = 50,
  140. dropout = 0.5f,
  141. batch_size = 20,
  142. lr_method = "adam",
  143. lr = 0.001f,
  144. lr_decay = 0.9f,
  145. clip = false,
  146. epoch_no_imprv = 3,
  147. hidden_size_char = 100,
  148. hidden_size_lstm = 300
  149. };
  150. hp.filepath_dev = hp.filepath_test = hp.filepath_train = Path.Combine(hp.data_root_dir, "test.txt");
  151. // Loads vocabulary, processing functions and embeddings
  152. hp.filepath_words = Path.Combine(hp.data_root_dir, "words.txt");
  153. hp.filepath_tags = Path.Combine(hp.data_root_dir, "tags.txt");
  154. hp.filepath_chars = Path.Combine(hp.data_root_dir, "chars.txt");
  155. string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/lstm_crf_ner.zip";
  156. Web.Download(url, hp.data_root_dir, "lstm_crf_ner.zip");
  157. Compress.UnZip(Path.Combine(hp.data_root_dir, "lstm_crf_ner.zip"), hp.data_root_dir);
  158. // 1. vocabulary
  159. /*vocab_tags = load_vocab(hp.filepath_tags);
  160. nwords = vocab_words.Count;
  161. nchars = vocab_chars.Count;
  162. ntags = vocab_tags.Count;*/
  163. // 2. get processing functions that map str -> id
  164. dev = new CoNLLDataset(hp.filepath_dev, hp);
  165. train = new CoNLLDataset(hp.filepath_train, hp);
  166. // download graph meta data
  167. var meta_file = "lstm_crf_ner.meta";
  168. var meta_path = Path.Combine("graph", meta_file);
  169. url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
  170. Web.Download(url, "graph", meta_file);
  171. }
  172. public Graph ImportGraph()
  173. {
  174. throw new NotImplementedException();
  175. }
  176. public Graph BuildGraph()
  177. {
  178. throw new NotImplementedException();
  179. }
  180. public void Train(Session sess)
  181. {
  182. throw new NotImplementedException();
  183. }
  184. public void Predict(Session sess)
  185. {
  186. throw new NotImplementedException();
  187. }
  188. public void Test(Session sess)
  189. {
  190. throw new NotImplementedException();
  191. }
  192. }
  193. }