You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LstmCrfNer.cs 2.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Linq;
  5. using System.Text;
  6. using Tensorflow;
  7. using Tensorflow.Estimator;
  8. using TensorFlowNET.Examples.Utility;
  9. using static Tensorflow.Python;
  10. namespace TensorFlowNET.Examples.Text.NER
  11. {
  /// <summary>
  /// A NER model using TensorFlow (LSTM + CRF + character embeddings).
  /// State-of-the-art performance (F1 score between 90 and 91).
  ///
  /// https://github.com/guillaumegenthial/sequence_tagging
  /// </summary>
  /// <remarks>
  /// This example is incomplete.
  /// <see cref="Run"/> imports a pre-built meta graph and runs its variable
  /// initializer, but the epoch loop only prints progress.
  /// The vocabulary-size fields are never assigned.
  /// </remarks>
  public class LstmCrfNer : IExample
  {
      /// <summary>Relative path of the serialized meta graph imported by <see cref="Run"/>.</summary>
      private const string MetaGraphPath = "graph/lstm_crf_ner.meta";

      public int Priority => 14;

      public bool Enabled { get; set; } = true;

      /// <summary>
      /// NOTE(review): this flag is not read by <see cref="Run"/>, which always imports
      /// <see cref="MetaGraphPath"/> regardless of its value.
      /// </summary>
      public bool ImportGraph { get; set; } = true;

      public string Name => "LSTM + CRF NER";

      // Hyper-parameters and data file paths; populated by PrepareData().
      private HyperParams hp;

      // Tag vocabulary; currently left empty (see the TODO in PrepareData).
      private Dictionary<string, int> vocab_tags = new Dictionary<string, int>();

      // Vocabulary sizes; not assigned anywhere yet (see the TODO in PrepareData).
      private int nwords, nchars, ntags;

      // Datasets built by PrepareData(); not consumed by Run() yet.
      private CoNLLDataset dev, train;

      /// <summary>
      /// Prepares the hyper-parameters and datasets.
      /// Imports the meta graph from <see cref="MetaGraphPath"/> and runs the
      /// global variable initializer in a new session.
      /// Then iterates over <c>hp.epochs</c> epochs, printing progress.
      /// </summary>
      /// <returns>Always <c>true</c>.</returns>
      public bool Run()
      {
          PrepareData();

          // Create a new graph and make it the default. The returned graph is not
          // needed here; presumably import_meta_graph populates the default graph.
          tf.Graph().as_default();
          tf.train.import_meta_graph(MetaGraphPath);

          var init = tf.global_variables_initializer();

          with(tf.Session(), sess =>
          {
              sess.run(init);

              foreach (var epoch in range(hp.epochs))
              {
                  // TODO: train on `train` and evaluate on `dev` for each epoch.
                  print($"Epoch {epoch + 1} out of {hp.epochs}");
              }
          });

          return true;
      }

      /// <summary>
      /// Creates the hyper-parameters for this example and sets the data file paths
      /// under <c>hp.data_root_dir</c>.
      /// Then builds the dev and train <see cref="CoNLLDataset"/> instances.
      /// </summary>
      public void PrepareData()
      {
          hp = new HyperParams("LstmCrfNer")
          {
              epochs = 15,
              dropout = 0.5f,
              batch_size = 20,
              lr_method = "adam",
              lr = 0.001f,
              lr_decay = 0.9f,
              clip = false,
              epoch_no_imprv = 3,
              hidden_size_char = 100,
              hidden_size_lstm = 300
          };

          // Dev, test and train deliberately share the same "test.txt" file.
          hp.filepath_dev = hp.filepath_test = hp.filepath_train = Path.Combine(hp.data_root_dir, "test.txt");

          // Word, tag and character vocabulary files. Only the paths are set here;
          // nothing is loaded from them yet.
          hp.filepath_words = Path.Combine(hp.data_root_dir, "words.txt");
          hp.filepath_tags = Path.Combine(hp.data_root_dir, "tags.txt");
          hp.filepath_chars = Path.Combine(hp.data_root_dir, "chars.txt");

          // 1. Vocabulary.
          // TODO: load vocab_tags from hp.filepath_tags (and the word/char vocabularies
          // from hp.filepath_words / hp.filepath_chars), then set nwords, nchars and ntags
          // from their sizes.

          // 2. Datasets for the dev and train file paths configured above.
          dev = new CoNLLDataset(hp.filepath_dev, hp);
          train = new CoNLLDataset(hp.filepath_train, hp);
      }
  }
  74. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。