You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

BinaryTextClassification.cs 5.8 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using Newtonsoft.Json;
using NumSharp;
using Tensorflow;
  8. namespace TensorFlowNET.Examples
  9. {
  10. /// <summary>
  11. /// This example classifies movie reviews as positive or negative using the text of the review.
  12. /// This is a binary—or two-class—classification, an important and widely applicable kind of machine learning problem.
  13. /// https://github.com/tensorflow/docs/blob/master/site/en/tutorials/keras/basic_text_classification.ipynb
  14. /// </summary>
  15. public class BinaryTextClassification : IExample
  16. {
  17. public bool Enabled { get; set; } = false;
  18. public string Name => "Binary Text Classification";
  19. public bool IsImportingGraph { get; set; } = true;
  20. string dir = "binary_text_classification";
  21. string dataFile = "imdb.zip";
  22. NDArray train_data, train_labels, test_data, test_labels;
  23. public bool Run()
  24. {
  25. PrepareData();
  26. Console.WriteLine($"Training entries: {train_data.shape[0]}, labels: {train_labels.shape[0]}");
  27. // A dictionary mapping words to an integer index
  28. var word_index = GetWordIndex();
  29. /*train_data = keras.preprocessing.sequence.pad_sequences(train_data,
  30. value: word_index["<PAD>"],
  31. padding: "post",
  32. maxlen: 256);
  33. test_data = keras.preprocessing.sequence.pad_sequences(test_data,
  34. value: word_index["<PAD>"],
  35. padding: "post",
  36. maxlen: 256);*/
  37. // input shape is the vocabulary count used for the movie reviews (10,000 words)
  38. int vocab_size = 10000;
  39. var model = keras.Sequential();
  40. var layer = keras.layers.Embedding(vocab_size, 16);
  41. model.add(layer);
  42. return false;
  43. }
  44. public void PrepareData()
  45. {
  46. Directory.CreateDirectory(dir);
  47. // get model file
  48. string url = $"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/{dataFile}";
  49. Utility.Web.Download(url, dir, "imdb.zip");
  50. Utility.Compress.UnZip(Path.Join(dir, $"imdb.zip"), dir);
  51. // prepare training dataset
  52. var x_train = ReadData(Path.Join(dir, "x_train.txt"));
  53. var labels_train = ReadData(Path.Join(dir, "y_train.txt"));
  54. var indices_train = ReadData(Path.Join(dir, "indices_train.txt"));
  55. x_train = x_train[indices_train];
  56. labels_train = labels_train[indices_train];
  57. var x_test = ReadData(Path.Join(dir, "x_test.txt"));
  58. var labels_test = ReadData(Path.Join(dir, "y_test.txt"));
  59. var indices_test = ReadData(Path.Join(dir, "indices_test.txt"));
  60. x_test = x_test[indices_test];
  61. labels_test = labels_test[indices_test];
  62. // not completed
  63. var xs = x_train.hstack(x_test);
  64. var labels = labels_train.hstack(labels_test);
  65. var idx = x_train.size;
  66. var y_train = labels_train;
  67. var y_test = labels_test;
  68. // convert x_train
  69. train_data = new NDArray(np.int32, (x_train.size, 256));
  70. /*for (int i = 0; i < x_train.size; i++)
  71. train_data[i] = x_train[i].Data<string>()[1].Split(',').Select(x => int.Parse(x)).ToArray();*/
  72. test_data = new NDArray(np.int32, (x_test.size, 256));
  73. /*for (int i = 0; i < x_test.size; i++)
  74. test_data[i] = x_test[i].Data<string>()[1].Split(',').Select(x => int.Parse(x)).ToArray();*/
  75. train_labels = y_train;
  76. test_labels = y_test;
  77. }
  78. private NDArray ReadData(string file)
  79. {
  80. var lines = File.ReadAllLines(file);
  81. var nd = new NDArray(lines[0].StartsWith("[") ? typeof(string) : np.int32, new Shape(lines.Length));
  82. if (lines[0].StartsWith("["))
  83. {
  84. for (int i = 0; i < lines.Length; i++)
  85. {
  86. /*var matches = Regex.Matches(lines[i], @"\d+\s*");
  87. var data = new int[matches.Count];
  88. for (int j = 0; j < data.Length; j++)
  89. data[j] = Convert.ToInt32(matches[j].Value);
  90. nd[i] = data.ToArray();*/
  91. nd[i] = lines[i].Substring(1, lines[i].Length - 2).Replace(" ", string.Empty);
  92. }
  93. }
  94. else
  95. {
  96. for (int i = 0; i < lines.Length; i++)
  97. nd[i] = Convert.ToInt32(lines[i]);
  98. }
  99. return nd;
  100. }
  101. private Dictionary<string, int> GetWordIndex()
  102. {
  103. var result = new Dictionary<string, int>();
  104. var json = File.ReadAllText(Path.Join(dir, "imdb_word_index.json"));
  105. var dict = JsonConvert.DeserializeObject<Dictionary<string, int>>(json);
  106. dict.Keys.Select(k => result[k] = dict[k] + 3).ToList();
  107. result["<PAD>"] = 0;
  108. result["<START>"] = 1;
  109. result["<UNK>"] = 2; // unknown
  110. result["<UNUSED>"] = 3;
  111. return result;
  112. }
  113. public Graph ImportGraph()
  114. {
  115. throw new NotImplementedException();
  116. }
  117. public Graph BuildGraph()
  118. {
  119. throw new NotImplementedException();
  120. }
  121. public void Train(Session sess)
  122. {
  123. throw new NotImplementedException();
  124. }
  125. public void Predict(Session sess)
  126. {
  127. throw new NotImplementedException();
  128. }
  129. public void Test(Session sess)
  130. {
  131. throw new NotImplementedException();
  132. }
  133. }
  134. }