You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

TextClassificationTrain.cs 6.8 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. using System;
  2. using System.Collections;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Text;
  7. using NumSharp;
  8. using Tensorflow;
  9. using Tensorflow.Keras.Engine;
  10. using TensorFlowNET.Examples.Text.cnn_models;
  11. using TensorFlowNET.Examples.TextClassification;
  12. using TensorFlowNET.Examples.Utility;
  13. namespace TensorFlowNET.Examples.CnnTextClassification
  14. {
  15. /// <summary>
  16. /// https://github.com/dongjun-Lee/text-classification-models-tf
  17. /// </summary>
  18. public class TextClassificationTrain : Python, IExample
  19. {
  20. public int Priority => 100;
  21. public bool Enabled { get; set; } = false;
  22. public string Name => "Text Classification";
  23. public int? DataLimit = null;
  24. public bool ImportGraph { get; set; } = true;
  25. private string dataDir = "text_classification";
  26. private string dataFileName = "dbpedia_csv.tar.gz";
  27. public string model_name = "vd_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn
  28. private const int CHAR_MAX_LEN = 1014;
  29. private const int NUM_CLASS = 2;
  30. private const int BATCH_SIZE = 64;
  31. private const int NUM_EPOCHS = 10;
  32. protected float loss_value = 0;
  33. public bool Run()
  34. {
  35. PrepareData();
  36. return with(tf.Session(), sess =>
  37. {
  38. if (ImportGraph)
  39. return RunWithImportedGraph(sess);
  40. else
  41. return RunWithBuiltGraph(sess);
  42. });
  43. }
  44. protected virtual bool RunWithImportedGraph(Session sess)
  45. {
  46. var graph = tf.Graph().as_default();
  47. Console.WriteLine("Building dataset...");
  48. var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);
  49. var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
  50. var meta_file = model_name + "_untrained.meta";
  51. tf.train.import_meta_graph(Path.Join("graph", meta_file));
  52. //sess.run(tf.global_variables_initializer()); // not necessary here, has already been done before meta graph export
  53. var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
  54. var num_batches_per_epoch = (len(train_x) - 1); // BATCH_SIZE + 1
  55. double max_accuracy = 0;
  56. Tensor is_training = graph.get_operation_by_name("is_training");
  57. Tensor model_x = graph.get_operation_by_name("x");
  58. Tensor model_y = graph.get_operation_by_name("y");
  59. Tensor loss = graph.get_operation_by_name("Variable");
  60. Tensor accuracy = graph.get_operation_by_name("accuracy/accuracy");
  61. foreach (var (x_batch, y_batch) in train_batches)
  62. {
  63. var train_feed_dict = new Hashtable
  64. {
  65. [model_x] = x_batch,
  66. [model_y] = y_batch,
  67. [is_training] = true,
  68. };
  69. //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
  70. }
  71. return false;
  72. }
  73. protected virtual bool RunWithBuiltGraph(Session session)
  74. {
  75. Console.WriteLine("Building dataset...");
  76. var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);
  77. var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
  78. ITextClassificationModel model = null;
  79. switch (model_name) // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn
  80. {
  81. case "word_cnn":
  82. case "char_cnn":
  83. case "word_rnn":
  84. case "att_rnn":
  85. case "rcnn":
  86. throw new NotImplementedException();
  87. break;
  88. case "vd_cnn":
  89. model=new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
  90. break;
  91. }
  92. // todo train the model
  93. return false;
  94. }
  95. private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f)
  96. {
  97. int len = x.Length;
  98. int classes = y.Distinct().Count();
  99. int samples = len / classes;
  100. int train_size = int.Parse((samples * (1 - test_size)).ToString());
  101. var train_x = new List<int[]>();
  102. var valid_x = new List<int[]>();
  103. var train_y = new List<int>();
  104. var valid_y = new List<int>();
  105. for (int i = 0; i < classes; i++)
  106. {
  107. for (int j = 0; j < samples; j++)
  108. {
  109. int idx = i * samples + j;
  110. if (idx < train_size + samples * i)
  111. {
  112. train_x.Add(x[idx]);
  113. train_y.Add(y[idx]);
  114. }
  115. else
  116. {
  117. valid_x.Add(x[idx]);
  118. valid_y.Add(y[idx]);
  119. }
  120. }
  121. }
  122. return (train_x.ToArray(), valid_x.ToArray(), train_y.ToArray(), valid_y.ToArray());
  123. }
  124. private IEnumerable<(NDArray, NDArray)> batch_iter(int[][] raw_inputs, int[] raw_outputs, int batch_size, int num_epochs)
  125. {
  126. var inputs = np.array(raw_inputs);
  127. var outputs = np.array(raw_outputs);
  128. var num_batches_per_epoch = (len(inputs) - 1); // batch_size + 1
  129. foreach (var epoch in range(num_epochs))
  130. {
  131. foreach (var batch_num in range(num_batches_per_epoch))
  132. {
  133. var start_index = batch_num * batch_size;
  134. var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
  135. yield return (inputs[$"{start_index}:{end_index}"], outputs[$"{start_index}:{end_index}"]);
  136. }
  137. }
  138. }
  139. public void PrepareData()
  140. {
  141. string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
  142. Web.Download(url, dataDir, dataFileName);
  143. Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);
  144. if (ImportGraph)
  145. {
  146. // download graph meta data
  147. var meta_file = model_name + "_untrained.meta";
  148. url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
  149. Web.Download(url, "graph", meta_file);
  150. }
  151. }
  152. }
  153. }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。