You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

BinaryTextClassification.cs 4.9 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using Tensorflow;
  5. using Newtonsoft.Json;
  6. using System.Linq;
  7. using System.Text.RegularExpressions;
  8. using NumSharp;
  9. using static Tensorflow.Python;
  10. namespace TensorFlowNET.Examples
  11. {
  /// <summary>
  /// This example classifies movie reviews as positive or negative using the text of the review.
  /// This is a binary (two-class) classification, an important and widely applicable kind of machine learning problem.
  /// https://github.com/tensorflow/docs/blob/master/site/en/tutorials/keras/basic_text_classification.ipynb
  /// </summary>
  /// <remarks>
  /// The example is not finished yet: only data preparation, padding and the first model layer
  /// are implemented, so <see cref="Run"/> currently always returns <c>false</c>.
  /// </remarks>
  public class BinaryTextClassification : IExample
  {
      public int Priority => 9;
      public bool Enabled { get; set; } = true;
      public string Name => "Binary Text Classification";
      public bool ImportGraph { get; set; } = true;

      // Working directory the dataset archive is downloaded to and extracted into.
      string dir = "binary_text_classification";

      // File name of the dataset archive, used both for the download URL and on disk.
      string dataFile = "imdb.zip";

      // Length every review sequence is padded to before it is fed to the model.
      const int MaxSequenceLength = 256;

      NDArray train_data, train_labels, test_data, test_labels;

      /// <summary>
      /// Prepares the dataset, pads every review to <see cref="MaxSequenceLength"/> entries
      /// and starts building the Keras model.
      /// </summary>
      /// <returns>Always <c>false</c> for now, because training and evaluation are not implemented yet.</returns>
      public bool Run()
      {
          PrepareData();

          Console.WriteLine($"Training entries: {train_data.size}, labels: {train_labels.size}");

          // A dictionary mapping words to an integer index
          var word_index = GetWordIndex();
          int padValue = word_index["<PAD>"];

          // Pad at the end ("post") so that all reviews share the same length.
          train_data = keras.preprocessing.sequence.pad_sequences(train_data,
              value: padValue,
              padding: "post",
              maxlen: MaxSequenceLength);

          test_data = keras.preprocessing.sequence.pad_sequences(test_data,
              value: padValue,
              padding: "post",
              maxlen: MaxSequenceLength);

          // input shape is the vocabulary count used for the movie reviews (10,000 words)
          int vocab_size = 10000;

          var model = keras.Sequential();
          model.add(keras.layers.Embedding(vocab_size, 16));

          // TODO: compile, train and evaluate the model, then report success.
          return false;
      }

      /// <summary>
      /// Downloads the dataset archive with <c>Utility.Web.Download</c> and extracts it into <c>dir</c>.
      /// Then it loads the train/test reviews and labels, reindexes each set with its
      /// <c>indices_*.txt</c> file, and stores the results in the data fields used by <see cref="Run"/>.
      /// </summary>
      public void PrepareData()
      {
          Directory.CreateDirectory(dir);

          // get the dataset archive
          string url = $"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/{dataFile}";
          Utility.Web.Download(url, dir, dataFile);
          Utility.Compress.UnZip(Path.Join(dir, dataFile), dir);

          // prepare training dataset
          var x_train = ReadData(Path.Join(dir, "x_train.txt"));
          var labels_train = ReadData(Path.Join(dir, "y_train.txt"));
          var indices_train = ReadData(Path.Join(dir, "indices_train.txt"));
          x_train = x_train[indices_train];
          labels_train = labels_train[indices_train];

          // prepare test dataset the same way
          var x_test = ReadData(Path.Join(dir, "x_test.txt"));
          var labels_test = ReadData(Path.Join(dir, "y_test.txt"));
          var indices_test = ReadData(Path.Join(dir, "indices_test.txt"));
          x_test = x_test[indices_test];
          labels_test = labels_test[indices_test];

          // Publish the prepared data through the fields that Run() reads.
          train_data = x_train;
          train_labels = labels_train;
          test_data = x_test;
          test_labels = labels_test;
      }

      /// <summary>
      /// Reads a text file with one value per line into a one-dimensional NDArray.
      /// </summary>
      /// <param name="file">Path of the file to read.</param>
      /// <returns>
      /// If the first line starts with <c>[</c>, returns a string array. Each element is a line with its
      /// first and last characters (the brackets) and all spaces removed, so "[1, 14, 22]" becomes "1,14,22".
      /// Otherwise, returns an int32 array with each line parsed as an integer.
      /// </returns>
      /// <exception cref="InvalidDataException">The file contains no lines.</exception>
      private NDArray ReadData(string file)
      {
          var lines = File.ReadAllLines(file);
          if (lines.Length == 0)
          {
              throw new InvalidDataException($"Data file '{file}' is empty.");
          }

          // The first line decides the format of the whole file.
          bool isSequence = lines[0].StartsWith("[", StringComparison.Ordinal);
          var nd = new NDArray(isSequence ? typeof(string) : np.int32, new Shape(lines.Length));

          for (int i = 0; i < lines.Length; i++)
          {
              if (isSequence)
              {
                  // "[1, 14, 22]" -> "1,14,22"
                  nd[i] = lines[i].Substring(1, lines[i].Length - 2).Replace(" ", string.Empty);
              }
              else
              {
                  // Machine-generated data: parse independently of the current culture.
                  nd[i] = int.Parse(lines[i], System.Globalization.CultureInfo.InvariantCulture);
              }
          }

          return nd;
      }

      /// <summary>
      /// Loads the word-to-index mapping from <c>imdb_word_index.json</c> in <c>dir</c>.
      /// Every index from the file is shifted up by 3.
      /// The reserved tokens <c>&lt;PAD&gt;</c>=0, <c>&lt;START&gt;</c>=1, <c>&lt;UNK&gt;</c>=2 and
      /// <c>&lt;UNUSED&gt;</c>=3 are then set, overwriting any same-named entries from the file.
      /// </summary>
      /// <returns>A new dictionary mapping each word (or reserved token) to its integer index.</returns>
      /// <exception cref="InvalidDataException">The JSON file deserializes to <c>null</c>.</exception>
      private Dictionary<string, int> GetWordIndex()
      {
          var json = File.ReadAllText(Path.Join(dir, "imdb_word_index.json"));
          var dict = JsonConvert.DeserializeObject<Dictionary<string, int>>(json)
              ?? throw new InvalidDataException("imdb_word_index.json does not contain a word index.");

          var result = new Dictionary<string, int>(dict.Count + 4);
          foreach (var pair in dict)
          {
              result[pair.Key] = pair.Value + 3;
          }

          result["<PAD>"] = 0;
          result["<START>"] = 1;
          result["<UNK>"] = 2; // unknown
          result["<UNUSED>"] = 3;
          return result;
      }
  }
  111. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。