You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TextClassificationTrain.cs 5.2 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Linq;
  5. using System.Text;
  6. using Tensorflow;
  7. using Tensorflow.Keras.Engine;
  8. using TensorFlowNET.Examples.Text.cnn_models;
  9. using TensorFlowNET.Examples.TextClassification;
  10. using TensorFlowNET.Examples.Utility;
  namespace TensorFlowNET.Examples.CnnTextClassification
  {
      /// <summary>
      /// Example that downloads the dbpedia_csv archive, builds a character-level dataset
      /// and sets up a text classification model, either from an imported meta graph or
      /// from a graph built in code. Training is not implemented yet, so both run paths
      /// currently return <c>false</c>.
      /// Based on https://github.com/dongjun-Lee/text-classification-models-tf
      /// </summary>
      public class TextClassificationTrain : Python, IExample
      {
          public int Priority => 100;
          public bool Enabled { get; set; } = false;
          public string Name => "Text Classification";

          /// <summary>Optional limit forwarded to <c>DataHelpers.build_char_dataset</c>; <c>null</c> passes no limit.</summary>
          public int? DataLimit = null;

          /// <summary>When true, the graph is imported from a downloaded ".meta" file instead of being built in code.</summary>
          public bool ImportGraph { get; set; } = true;

          private string dataDir = "text_classification";
          private string dataFileName = "dbpedia_csv.tar.gz";

          public string model_name = "vd_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn

          private const int CHAR_MAX_LEN = 1014;
          private const int NUM_CLASS = 2;

          protected float loss_value = 0;

          /// <summary>
          /// Prepares the data, opens a session and dispatches to the imported-graph or
          /// built-graph path depending on <see cref="ImportGraph"/>.
          /// </summary>
          /// <returns>The result of the selected run path.</returns>
          public bool Run()
          {
              PrepareData();

              return with(tf.Session(), sess =>
              {
                  if (ImportGraph)
                  {
                      return RunWithImportedGraph(sess);
                  }
                  else
                  {
                      return RunWithBuiltGraph(sess);
                  }
              });
          }

          /// <summary>
          /// Builds the dataset, imports the untrained meta graph for <see cref="model_name"/>
          /// from the "graph" directory and looks up the "is_training", "x" and "y" operations.
          /// </summary>
          /// <param name="sess">Session supplied by <see cref="Run"/>; not used yet.</param>
          /// <returns>Always <c>false</c>: the training loop is not implemented yet.</returns>
          protected virtual bool RunWithImportedGraph(Session sess)
          {
              var graph = tf.Graph().as_default();

              Console.WriteLine("Building dataset...");
              var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);

              var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

              var meta_file = model_name + "_untrained.meta";
              tf.train.import_meta_graph(Path.Join("graph", meta_file));

              //sess.run(tf.global_variables_initializer());

              // Placeholders fetched by name from the imported graph; kept for the future training loop.
              Tensor is_training = graph.get_operation_by_name("is_training");
              Tensor model_x = graph.get_operation_by_name("x");
              Tensor model_y = graph.get_operation_by_name("y");
              //Tensor loss = graph.get_operation_by_name("loss");
              //Tensor accuracy = graph.get_operation_by_name("accuracy");

              return false;
          }

          /// <summary>
          /// Builds the dataset and instantiates the model selected by <see cref="model_name"/>.
          /// Only "vd_cnn" is implemented.
          /// </summary>
          /// <param name="session">Session supplied by <see cref="Run"/>; not used yet.</param>
          /// <returns>Always <c>false</c>: the training loop is not implemented yet.</returns>
          /// <exception cref="NotImplementedException">
          /// <see cref="model_name"/> is "word_cnn", "char_cnn", "word_rnn", "att_rnn" or "rcnn".
          /// </exception>
          protected virtual bool RunWithBuiltGraph(Session session)
          {
              Console.WriteLine("Building dataset...");
              var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);

              var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

              // Stays null for a model name that matches none of the cases below.
              ITextClassificationModel model = null;
              switch (model_name) // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn
              {
                  case "word_cnn":
                  case "char_cnn":
                  case "word_rnn":
                  case "att_rnn":
                  case "rcnn":
                      // No break needed: a throw already terminates the switch section.
                      throw new NotImplementedException();
                  case "vd_cnn":
                      model = new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
                      break;
              }

              // todo train the model
              return false;
          }

          /// <summary>
          /// Splits the samples into training and validation sets, class by class.
          /// The indexing treats <paramref name="x"/>/<paramref name="y"/> as consecutive groups of
          /// <c>x.Length / classes</c> samples, one group per distinct label; samples beyond
          /// <c>classes * (x.Length / classes)</c> are not placed in either set.
          /// </summary>
          /// <param name="x">Samples.</param>
          /// <param name="y">Labels; <c>y[i]</c> is paired with <c>x[i]</c>.</param>
          /// <param name="test_size">Fraction of each group that goes to the validation set.</param>
          /// <returns>The tuple (train_x, valid_x, train_y, valid_y).</returns>
          /// <exception cref="ArgumentException"><paramref name="x"/> is non-empty but <paramref name="y"/> has no labels.</exception>
          private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f)
          {
              int len = x.Length;
              if (len == 0)
              {
                  // Nothing to split; also avoids dividing by a zero class count below.
                  return (new int[0][], new int[0][], new int[0], new int[0]);
              }

              int classes = y.Distinct().Count();
              if (classes == 0)
              {
                  throw new ArgumentException("At least one label is required when samples are provided.", nameof(y));
              }

              int samples = len / classes;

              // Round instead of formatting the float and parsing it back as an int:
              // int.Parse throws FormatException for any non-integral value such as "8.5".
              int train_size = (int)Math.Round(samples * (1 - test_size));

              var train_x = new List<int[]>();
              var valid_x = new List<int[]>();
              var train_y = new List<int>();
              var valid_y = new List<int>();

              for (int i = 0; i < classes; i++)
              {
                  for (int j = 0; j < samples; j++)
                  {
                      int idx = i * samples + j;

                      // The first train_size samples of each group are used for training.
                      if (j < train_size)
                      {
                          train_x.Add(x[idx]);
                          train_y.Add(y[idx]);
                      }
                      else
                      {
                          valid_x.Add(x[idx]);
                          valid_y.Add(y[idx]);
                      }
                  }
              }

              return (train_x.ToArray(), valid_x.ToArray(), train_y.ToArray(), valid_y.ToArray());
          }

          /// <summary>
          /// Downloads and extracts the dataset archive and, when <see cref="ImportGraph"/> is set,
          /// downloads the untrained meta graph for <see cref="model_name"/> into the "graph" directory.
          /// </summary>
          public void PrepareData()
          {
              string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
              Web.Download(url, dataDir, dataFileName);
              Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);

              if (ImportGraph)
              {
                  // download graph meta data
                  var meta_file = model_name + "_untrained.meta";
                  url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
                  Web.Download(url, "graph", meta_file);
              }
          }
      }
  }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。