You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

LstmCrfNer.cs 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Text;
  7. using Tensorflow;
  8. using Tensorflow.Estimator;
  9. using TensorFlowNET.Examples.Utility;
  10. using static Tensorflow.Python;
  11. using static TensorFlowNET.Examples.DataHelpers;
namespace TensorFlowNET.Examples.Text.NER
{
    /// <summary>
    /// A NER model using Tensorflow (LSTM + CRF + chars embeddings).
    /// State-of-the-art performance (F1 score between 90 and 91).
    ///
    /// https://github.com/guillaumegenthial/sequence_tagging
    /// </summary>
    public class LstmCrfNer : IExample
    {
        public int Priority => 14;
        public bool Enabled { get; set; } = true;
        public bool ImportGraph { get; set; } = true;
        public string Name => "LSTM + CRF NER";

        // Training hyper-parameters (epochs, batch size, learning rate, ...),
        // built in PrepareData().
        HyperParams hp;
        // Vocabulary sizes for words, characters and tags.
        // NOTE(review): only assigned in commented-out code in PrepareData(),
        // so these currently stay 0 — confirm whether that is intended.
        int nwords, nchars, ntags;
        // Development and training datasets (CoNLL format).
        CoNLLDataset dev, train;

        // Graph nodes resolved by name in Run() from the imported meta graph.
        Tensor word_ids_tensor;
        Tensor sequence_lengths_tensor;
        Tensor char_ids_tensor;
        Tensor word_lengths_tensor;
        Tensor labels_tensor;
        Tensor dropout_tensor;
        Tensor lr_tensor;
        Operation train_op;       // "train_step/Adam" — one optimizer step
        Tensor loss;              // "Mean" — mean loss over the batch
        Tensor merged;            // summaries; lookup is commented out in Run()
        /// <summary>
        /// Entry point: imports the pre-built LSTM+CRF meta graph, resolves its
        /// placeholders and ops by name, then trains for hp.epochs epochs.
        /// </summary>
        /// <returns>true when the last epoch's training loss is below 0.1.</returns>
        public bool Run()
        {
            PrepareData();
            var graph = tf.Graph().as_default();

            // Graph structure comes from the downloaded .meta file, not from
            // code in this class; PrepareData() downloads it into graph/.
            tf.train.import_meta_graph("graph/lstm_crf_ner.meta");

            float loss_value = 0f;
            //add_summary();

            // Resolve nodes by the names they were exported under.
            word_ids_tensor = graph.OperationByName("word_ids");
            sequence_lengths_tensor = graph.OperationByName("sequence_lengths");
            char_ids_tensor = graph.OperationByName("char_ids");
            word_lengths_tensor = graph.OperationByName("word_lengths");
            labels_tensor = graph.OperationByName("labels");
            dropout_tensor = graph.OperationByName("dropout");
            lr_tensor = graph.OperationByName("lr");
            train_op = graph.OperationByName("train_step/Adam");
            loss = graph.OperationByName("Mean");
            //merged = graph.OperationByName("Merge/MergeSummary");

            var init = tf.global_variables_initializer();

            with(tf.Session(), sess =>
            {
                sess.run(init);

                foreach (var epoch in range(hp.epochs))
                {
                    Console.Write($"Epoch {epoch + 1} out of {hp.epochs}, ");
                    // loss_value is captured by the lambda so the final epoch's
                    // loss is visible to the return statement below.
                    loss_value = run_epoch(sess, train, dev, epoch);
                    print($"train loss: {loss_value}");
                }
            });

            return loss_value < 0.1;
        }
  69. private float run_epoch(Session sess, CoNLLDataset train, CoNLLDataset dev, int epoch)
  70. {
  71. NDArray results = null;
  72. // iterate over dataset
  73. var batches = minibatches(train, hp.batch_size);
  74. foreach (var(words, labels) in batches)
  75. {
  76. var (fd, _) = get_feed_dict(words, labels, hp.lr, hp.dropout);
  77. results = sess.run(new ITensorOrOperation[] { train_op, loss }, feed_dict: fd);
  78. }
  79. return results[1];
  80. }
  81. private IEnumerable<((int[][], int[])[], int[][])> minibatches(CoNLLDataset data, int minibatch_size)
  82. {
  83. var x_batch = new List<(int[][], int[])>();
  84. var y_batch = new List<int[]>();
  85. foreach(var (x, y) in data.GetItems())
  86. {
  87. if (len(y_batch) == minibatch_size)
  88. {
  89. yield return (x_batch.ToArray(), y_batch.ToArray());
  90. x_batch.Clear();
  91. y_batch.Clear();
  92. }
  93. var x3 = (x.Select(x1 => x1.Item1).ToArray(), x.Select(x2 => x2.Item2).ToArray());
  94. x_batch.Add(x3);
  95. y_batch.Add(y);
  96. }
  97. if (len(y_batch) > 0)
  98. yield return (x_batch.ToArray(), y_batch.ToArray());
  99. }
  100. /// <summary>
  101. /// Given some data, pad it and build a feed dictionary
  102. /// </summary>
  103. /// <param name="words">
  104. /// list of sentences. A sentence is a list of ids of a list of
  105. /// words. A word is a list of ids
  106. /// </param>
  107. /// <param name="labels">list of ids</param>
  108. /// <param name="lr">learning rate</param>
  109. /// <param name="dropout">keep prob</param>
  110. private (FeedItem[], int[]) get_feed_dict((int[][], int[])[] words, int[][] labels, float lr = 0f, float dropout = 0f)
  111. {
  112. int[] sequence_lengths;
  113. int[][] word_lengths;
  114. int[][] word_ids;
  115. int[][][] char_ids;
  116. if (true) // use_chars
  117. {
  118. (char_ids, word_ids) = (words.Select(x => x.Item1).ToArray(), words.Select(x => x.Item2).ToArray());
  119. (word_ids, sequence_lengths) = pad_sequences(word_ids, pad_tok: 0);
  120. (char_ids, word_lengths) = pad_sequences(char_ids, pad_tok: 0);
  121. }
  122. // build feed dictionary
  123. var feeds = new List<FeedItem>();
  124. feeds.Add(new FeedItem(word_ids_tensor, np.array(word_ids)));
  125. feeds.Add(new FeedItem(sequence_lengths_tensor, np.array(sequence_lengths)));
  126. if(true) // use_chars
  127. {
  128. feeds.Add(new FeedItem(char_ids_tensor, np.array(char_ids)));
  129. feeds.Add(new FeedItem(word_lengths_tensor, np.array(word_lengths)));
  130. }
  131. (labels, _) = pad_sequences(labels, 0);
  132. feeds.Add(new FeedItem(labels_tensor, np.array(labels)));
  133. feeds.Add(new FeedItem(lr_tensor, lr));
  134. feeds.Add(new FeedItem(dropout_tensor, dropout));
  135. return (feeds.ToArray(), sequence_lengths);
  136. }
  137. public void PrepareData()
  138. {
  139. hp = new HyperParams("LstmCrfNer")
  140. {
  141. epochs = 50,
  142. dropout = 0.5f,
  143. batch_size = 20,
  144. lr_method = "adam",
  145. lr = 0.001f,
  146. lr_decay = 0.9f,
  147. clip = false,
  148. epoch_no_imprv = 3,
  149. hidden_size_char = 100,
  150. hidden_size_lstm = 300
  151. };
  152. hp.filepath_dev = hp.filepath_test = hp.filepath_train = Path.Combine(hp.data_root_dir, "test.txt");
  153. // Loads vocabulary, processing functions and embeddings
  154. hp.filepath_words = Path.Combine(hp.data_root_dir, "words.txt");
  155. hp.filepath_tags = Path.Combine(hp.data_root_dir, "tags.txt");
  156. hp.filepath_chars = Path.Combine(hp.data_root_dir, "chars.txt");
  157. string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/lstm_crf_ner.zip";
  158. Web.Download(url, hp.data_root_dir, "lstm_crf_ner.zip");
  159. Compress.UnZip(Path.Combine(hp.data_root_dir, "lstm_crf_ner.zip"), hp.data_root_dir);
  160. // 1. vocabulary
  161. /*vocab_tags = load_vocab(hp.filepath_tags);
  162. nwords = vocab_words.Count;
  163. nchars = vocab_chars.Count;
  164. ntags = vocab_tags.Count;*/
  165. // 2. get processing functions that map str -> id
  166. dev = new CoNLLDataset(hp.filepath_dev, hp);
  167. train = new CoNLLDataset(hp.filepath_train, hp);
  168. // download graph meta data
  169. var meta_file = "lstm_crf_ner.meta";
  170. var meta_path = Path.Combine("graph", meta_file);
  171. url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
  172. Web.Download(url, "graph", meta_file);
  173. }
  174. }
  175. }