You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

LstmCrfNer.cs 8.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Text;
  7. using Tensorflow;
  8. using Tensorflow.Estimator;
  9. using TensorFlowNET.Examples.Utility;
  10. using static Tensorflow.Python;
  11. using static TensorFlowNET.Examples.DataHelpers;
  12. namespace TensorFlowNET.Examples.Text.NER
  13. {
  14. /// <summary>
  15. /// A NER model using Tensorflow (LSTM + CRF + chars embeddings).
  16. /// State-of-the-art performance (F1 score between 90 and 91).
  17. ///
  18. /// https://github.com/guillaumegenthial/sequence_tagging
  19. /// </summary>
  20. public class LstmCrfNer : IExample
  21. {
  22. public bool Enabled { get; set; } = true;
  23. public bool IsImportingGraph { get; set; } = true;
  24. public string Name => "LSTM + CRF NER";
  25. HyperParams hp;
  26. int nwords, nchars, ntags;
  27. CoNLLDataset dev, train;
  28. Tensor word_ids_tensor;
  29. Tensor sequence_lengths_tensor;
  30. Tensor char_ids_tensor;
  31. Tensor word_lengths_tensor;
  32. Tensor labels_tensor;
  33. Tensor dropout_tensor;
  34. Tensor lr_tensor;
  35. Operation train_op;
  36. Tensor loss;
  37. Tensor merged;
  38. public bool Run()
  39. {
  40. PrepareData();
  41. var graph = tf.Graph().as_default();
  42. tf.train.import_meta_graph("graph/lstm_crf_ner.meta");
  43. float loss_value = 0f;
  44. //add_summary();
  45. word_ids_tensor = graph.OperationByName("word_ids");
  46. sequence_lengths_tensor = graph.OperationByName("sequence_lengths");
  47. char_ids_tensor = graph.OperationByName("char_ids");
  48. word_lengths_tensor = graph.OperationByName("word_lengths");
  49. labels_tensor = graph.OperationByName("labels");
  50. dropout_tensor = graph.OperationByName("dropout");
  51. lr_tensor = graph.OperationByName("lr");
  52. train_op = graph.OperationByName("train_step/Adam");
  53. loss = graph.OperationByName("Mean");
  54. //merged = graph.OperationByName("Merge/MergeSummary");
  55. var init = tf.global_variables_initializer();
  56. with(tf.Session(), sess =>
  57. {
  58. sess.run(init);
  59. foreach (var epoch in range(hp.epochs))
  60. {
  61. Console.Write($"Epoch {epoch + 1} out of {hp.epochs}, ");
  62. loss_value = run_epoch(sess, train, dev, epoch);
  63. print($"train loss: {loss_value}");
  64. }
  65. });
  66. return loss_value < 0.1;
  67. }
  68. private float run_epoch(Session sess, CoNLLDataset train, CoNLLDataset dev, int epoch)
  69. {
  70. NDArray results = null;
  71. // iterate over dataset
  72. var batches = minibatches(train, hp.batch_size);
  73. foreach (var(words, labels) in batches)
  74. {
  75. var (fd, _) = get_feed_dict(words, labels, hp.lr, hp.dropout);
  76. results = sess.run(new ITensorOrOperation[] { train_op, loss }, feed_dict: fd);
  77. }
  78. return results[1];
  79. }
  80. private IEnumerable<((int[][], int[])[], int[][])> minibatches(CoNLLDataset data, int minibatch_size)
  81. {
  82. var x_batch = new List<(int[][], int[])>();
  83. var y_batch = new List<int[]>();
  84. foreach(var (x, y) in data.GetItems())
  85. {
  86. if (len(y_batch) == minibatch_size)
  87. {
  88. yield return (x_batch.ToArray(), y_batch.ToArray());
  89. x_batch.Clear();
  90. y_batch.Clear();
  91. }
  92. var x3 = (x.Select(x1 => x1.Item1).ToArray(), x.Select(x2 => x2.Item2).ToArray());
  93. x_batch.Add(x3);
  94. y_batch.Add(y);
  95. }
  96. if (len(y_batch) > 0)
  97. yield return (x_batch.ToArray(), y_batch.ToArray());
  98. }
  99. /// <summary>
  100. /// Given some data, pad it and build a feed dictionary
  101. /// </summary>
  102. /// <param name="words">
  103. /// list of sentences. A sentence is a list of ids of a list of
  104. /// words. A word is a list of ids
  105. /// </param>
  106. /// <param name="labels">list of ids</param>
  107. /// <param name="lr">learning rate</param>
  108. /// <param name="dropout">keep prob</param>
  109. private (FeedItem[], int[]) get_feed_dict((int[][], int[])[] words, int[][] labels, float lr = 0f, float dropout = 0f)
  110. {
  111. int[] sequence_lengths;
  112. int[][] word_lengths;
  113. int[][] word_ids;
  114. int[][][] char_ids;
  115. if (true) // use_chars
  116. {
  117. (char_ids, word_ids) = (words.Select(x => x.Item1).ToArray(), words.Select(x => x.Item2).ToArray());
  118. (word_ids, sequence_lengths) = pad_sequences(word_ids, pad_tok: 0);
  119. (char_ids, word_lengths) = pad_sequences(char_ids, pad_tok: 0);
  120. }
  121. // build feed dictionary
  122. var feeds = new List<FeedItem>();
  123. feeds.Add(new FeedItem(word_ids_tensor, np.array(word_ids)));
  124. feeds.Add(new FeedItem(sequence_lengths_tensor, np.array(sequence_lengths)));
  125. if(true) // use_chars
  126. {
  127. feeds.Add(new FeedItem(char_ids_tensor, np.array(char_ids)));
  128. feeds.Add(new FeedItem(word_lengths_tensor, np.array(word_lengths)));
  129. }
  130. (labels, _) = pad_sequences(labels, 0);
  131. feeds.Add(new FeedItem(labels_tensor, np.array(labels)));
  132. feeds.Add(new FeedItem(lr_tensor, lr));
  133. feeds.Add(new FeedItem(dropout_tensor, dropout));
  134. return (feeds.ToArray(), sequence_lengths);
  135. }
  136. public void PrepareData()
  137. {
  138. hp = new HyperParams("LstmCrfNer")
  139. {
  140. epochs = 50,
  141. dropout = 0.5f,
  142. batch_size = 20,
  143. lr_method = "adam",
  144. lr = 0.001f,
  145. lr_decay = 0.9f,
  146. clip = false,
  147. epoch_no_imprv = 3,
  148. hidden_size_char = 100,
  149. hidden_size_lstm = 300
  150. };
  151. hp.filepath_dev = hp.filepath_test = hp.filepath_train = Path.Combine(hp.data_root_dir, "test.txt");
  152. // Loads vocabulary, processing functions and embeddings
  153. hp.filepath_words = Path.Combine(hp.data_root_dir, "words.txt");
  154. hp.filepath_tags = Path.Combine(hp.data_root_dir, "tags.txt");
  155. hp.filepath_chars = Path.Combine(hp.data_root_dir, "chars.txt");
  156. string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/lstm_crf_ner.zip";
  157. Web.Download(url, hp.data_root_dir, "lstm_crf_ner.zip");
  158. Compress.UnZip(Path.Combine(hp.data_root_dir, "lstm_crf_ner.zip"), hp.data_root_dir);
  159. // 1. vocabulary
  160. /*vocab_tags = load_vocab(hp.filepath_tags);
  161. nwords = vocab_words.Count;
  162. nchars = vocab_chars.Count;
  163. ntags = vocab_tags.Count;*/
  164. // 2. get processing functions that map str -> id
  165. dev = new CoNLLDataset(hp.filepath_dev, hp);
  166. train = new CoNLLDataset(hp.filepath_train, hp);
  167. // download graph meta data
  168. var meta_file = "lstm_crf_ner.meta";
  169. var meta_path = Path.Combine("graph", meta_file);
  170. url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
  171. Web.Download(url, "graph", meta_file);
  172. }
  173. public Graph ImportGraph()
  174. {
  175. throw new NotImplementedException();
  176. }
  177. public Graph BuildGraph()
  178. {
  179. throw new NotImplementedException();
  180. }
  181. public void Train(Session sess)
  182. {
  183. throw new NotImplementedException();
  184. }
  185. public void Predict(Session sess)
  186. {
  187. throw new NotImplementedException();
  188. }
  189. public void Test(Session sess)
  190. {
  191. throw new NotImplementedException();
  192. }
  193. }
  194. }