
TextClassificationTrain.cs

using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using NumSharp;
using Tensorflow;
using Tensorflow.Keras.Engine;
using TensorFlowNET.Examples.Text.cnn_models;
using TensorFlowNET.Examples.TextClassification;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Python;

namespace TensorFlowNET.Examples.CnnTextClassification
{
    /// <summary>
    /// https://github.com/dongjun-Lee/text-classification-models-tf
    /// </summary>
    public class TextClassificationTrain : IExample
    {
        public int Priority => 100;
        public bool Enabled { get; set; } = false;
        public string Name => "Text Classification";
        public int? DataLimit = null;
        public bool ImportGraph { get; set; } = true;

        private string dataDir = "text_classification";
        private string dataFileName = "dbpedia_csv.tar.gz";

        public string model_name = "vd_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn

        private const int CHAR_MAX_LEN = 1014;
        private const int NUM_CLASS = 2;
        private const int BATCH_SIZE = 64;
        private const int NUM_EPOCHS = 10;

        protected float loss_value = 0;

        public bool Run()
        {
            PrepareData();

            return with(tf.Session(), sess =>
            {
                if (ImportGraph)
                    return RunWithImportedGraph(sess);
                else
                    return RunWithBuiltGraph(sess);
            });
        }
        protected virtual bool RunWithImportedGraph(Session sess)
        {
            var graph = tf.Graph().as_default();

            Console.WriteLine("Building dataset...");
            var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);
            var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

            // import the pre-built, untrained graph exported for this model
            var meta_file = model_name + "_untrained.meta";
            tf.train.import_meta_graph(Path.Join("graph", meta_file));
            //sess.run(tf.global_variables_initializer()); // not necessary here, has already been done before meta graph export

            var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
            var num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1;
            double max_accuracy = 0;

            // look up the tensors of the imported graph by their exported names
            Tensor is_training = graph.get_operation_by_name("is_training");
            Tensor model_x = graph.get_operation_by_name("x");
            Tensor model_y = graph.get_operation_by_name("y");
            Tensor loss = graph.get_operation_by_name("Variable");
            Tensor accuracy = graph.get_operation_by_name("accuracy/accuracy");

            foreach (var (x_batch, y_batch) in train_batches)
            {
                var train_feed_dict = new Hashtable
                {
                    [model_x] = x_batch,
                    [model_y] = y_batch,
                    [is_training] = true,
                };
                //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
            }

            return false;
        }
        protected virtual bool RunWithBuiltGraph(Session session)
        {
            Console.WriteLine("Building dataset...");
            var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);
            var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

            ITextClassificationModel model = null;
            switch (model_name) // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn
            {
                case "word_cnn":
                case "char_cnn":
                case "word_rnn":
                case "att_rnn":
                case "rcnn":
                    throw new NotImplementedException();
                case "vd_cnn":
                    model = new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
                    break;
            }

            // todo train the model
            return false;
        }
        private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f)
        {
            int len = x.Length;
            int classes = y.Distinct().Count();
            int samples = len / classes;
            int train_size = (int)(samples * (1 - test_size));

            var train_x = new List<int[]>();
            var valid_x = new List<int[]>();
            var train_y = new List<int>();
            var valid_y = new List<int>();

            // the data is laid out in consecutive blocks of `samples` items per class:
            // the first train_size items of each class go to training, the rest to validation
            for (int i = 0; i < classes; i++)
            {
                for (int j = 0; j < samples; j++)
                {
                    int idx = i * samples + j;
                    if (idx < train_size + samples * i)
                    {
                        train_x.Add(x[idx]);
                        train_y.Add(y[idx]);
                    }
                    else
                    {
                        valid_x.Add(x[idx]);
                        valid_y.Add(y[idx]);
                    }
                }
            }

            return (train_x.ToArray(), valid_x.ToArray(), train_y.ToArray(), valid_y.ToArray());
        }
        private IEnumerable<(NDArray, NDArray)> batch_iter(int[][] raw_inputs, int[] raw_outputs, int batch_size, int num_epochs)
        {
            var inputs = np.array(raw_inputs);
            var outputs = np.array(raw_outputs);

            var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1;
            foreach (var epoch in range(num_epochs))
            {
                foreach (var batch_num in range(num_batches_per_epoch))
                {
                    var start_index = batch_num * batch_size;
                    var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
                    // NumSharp string slicing: rows [start_index, end_index)
                    yield return (inputs[$"{start_index}:{end_index}"], outputs[$"{start_index}:{end_index}"]);
                }
            }
        }
        public void PrepareData()
        {
            string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
            Web.Download(url, dataDir, dataFileName);
            Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);

            if (ImportGraph)
            {
                // download graph meta data
                var meta_file = model_name + "_untrained.meta";
                url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
                Web.Download(url, "graph", meta_file);
            }
        }
    }
}
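
A quick way to exercise the class above is to construct and run it by hand, as in the sketch below. The driver code here is hypothetical and only assumes the members defined in this file; normally such IExample classes are driven by the examples project's own runner.

// Hypothetical driver, not part of the examples project.
var example = new TensorFlowNET.Examples.CnnTextClassification.TextClassificationTrain
{
    Enabled = true,      // the example is disabled by default
    ImportGraph = true   // run against the pre-exported "vd_cnn_untrained.meta" graph
};
bool result = example.Run();   // downloads dbpedia_csv.tar.gz, builds the char dataset, iterates training batches
Console.WriteLine($"{example.Name} returned: {result}");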

TensorFlow.NET is the .NET version of the TensorFlow framework; it provides a rich set of features and APIs that make it easy to build deep-learning training and inference pipelines on the .NET platform.
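
As a rough illustration of that API surface, the session-and-graph pattern used throughout TextClassificationTrain boils down to the following sketch. Only calls that already appear in the file are used; the meta-graph path and tensor name are placeholders, and the same using directives as above (Tensorflow, static Tensorflow.Python) are assumed.

// Minimal sketch of the import-and-query pattern from RunWithImportedGraph.
// "graph/some_model.meta" and "x" are placeholder names, not real artifacts.
var ok = with(tf.Session(), sess =>
{
    var graph = tf.Graph().as_default();
    tf.train.import_meta_graph("graph/some_model.meta");   // restore a previously exported graph
    Tensor input = graph.get_operation_by_name("x");        // look tensors up by their exported names
    // ... build a feed dictionary and call sess.run(...) per batch, as in RunWithImportedGraph
    return true;
});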