You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

DatasetTest.cs 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Tensorflow.NumPy;
  6. using static Tensorflow.Binding;
  7. using static Tensorflow.KerasApi;
  8. namespace TensorFlowNET.UnitTest.Dataset
  9. {
  10. [TestClass]
  11. public class DatasetTest : EagerModeTestBase
  12. {
  /// <summary>
  /// Checks that <c>tf.data.Dataset.range(3)</c> produces 0, 1, 2 in order and
  /// that <c>enumerate</c> pairs each element with a step counter starting at 0.
  /// </summary>
  [TestMethod]
  public void Range()
  {
      int iStep = 0;
      long value = 0;

      var dataset = tf.data.Dataset.range(3);

      foreach (var (step, item) in enumerate(dataset))
      {
          Assert.AreEqual(iStep, step);
          iStep++;

          Assert.AreEqual(value, (long)item.Item1);
          value++;
      }

      // Guard against a vacuous pass: an empty dataset would skip every
      // assertion inside the loop above.
      Assert.AreEqual(3, iStep);
  }
  /// <summary>
  /// Checks that adding <c>prefetch(2)</c> to <c>range(1, 5, 2)</c> leaves the
  /// produced elements (1, then 3) and their order unchanged.
  /// </summary>
  [TestMethod]
  public void Prefetch()
  {
      int iStep = 0;
      long value = 1;

      var dataset = tf.data.Dataset.range(1, 5, 2);
      dataset = dataset.prefetch(2);

      foreach (var (step, item) in enumerate(dataset))
      {
          Assert.AreEqual(iStep, step);
          iStep++;

          Assert.AreEqual(value, (long)item.Item1);
          value += 2;
      }

      // range(1, 5, 2) contains exactly two elements; without this check the
      // test would also pass if iteration produced nothing.
      Assert.AreEqual(2, iStep);
  }
  /// <summary>
  /// Builds a dataset from two five-element tensors with
  /// <c>from_tensor_slices</c>, prints every (x, y) pair and checks that five
  /// pairs were produced.
  /// </summary>
  [TestMethod]
  public void FromTensorSlices()
  {
      var years = tf.constant(new[] { 2013, 2014, 2015, 2016, 2017 });
      var amounts = tf.constant(new[] { 12000, 14000, 15000, 16500, 17500 });
      var pairs = tf.data.Dataset.from_tensor_slices(years, amounts);

      int pairCount = 0;
      foreach (var (item_x, item_y) in pairs)
      {
          print($"x:{item_x.numpy()},y:{item_y.numpy()}");
          pairCount++;
      }

      Assert.AreEqual(5, pairCount);
  }
  /// <summary>
  /// Wraps a whole int array with <c>from_tensors</c> and checks that the
  /// dataset yields exactly one element whose contents equal the source array.
  /// </summary>
  [TestMethod]
  public void FromTensor()
  {
      var source = new[] { 2013, 2014, 2015, 2016, 2017 };
      var wrapped = tf.data.Dataset.from_tensors(source);

      int elementCount = 0;
      foreach (var element in wrapped)
      {
          var contents = element.Item1.ToArray<int>();
          Assert.IsTrue(source.SequenceEqual(contents));
          elementCount++;
      }

      Assert.AreEqual(1, elementCount);
  }
  /// <summary>
  /// Checks that sharding <c>range(10)</c> into three shards yields every third
  /// element starting at the shard index: 0, 3, 6, 9 for index 0 and 1, 4, 7
  /// for index 1.
  /// </summary>
  [TestMethod]
  public void Shard()
  {
      var dataset1 = tf.data.Dataset.range(10);

      long value = 0;
      int count = 0;
      var dataset2 = dataset1.shard(num_shards: 3, index: 0);
      foreach (var item in dataset2)
      {
          Assert.AreEqual(value, (long)item.Item1);
          value += 3;
          count++;
      }

      // Shard 0 of 3 over 0..9 holds four elements; this also rules out a
      // vacuous pass on an empty shard.
      Assert.AreEqual(4, count);

      value = 1;
      count = 0;
      var dataset3 = dataset1.shard(num_shards: 3, index: 1);
      foreach (var item in dataset3)
      {
          Assert.AreEqual(value, (long)item.Item1);
          value += 3;
          count++;
      }

      // Shard 1 of 3 over 0..9 holds three elements.
      Assert.AreEqual(3, count);
  }
  /// <summary>
  /// Checks that <c>skip(7)</c> on <c>range(10)</c> drops the first seven
  /// elements and yields 7, 8, 9.
  /// </summary>
  [TestMethod]
  public void Skip()
  {
      long value = 7;

      var dataset = tf.data.Dataset.range(10);
      dataset = dataset.skip(7);

      foreach (var item in dataset)
      {
          Assert.AreEqual(value, (long)item.Item1);
          value++;
      }

      // Three elements must have been visited (7, 8, 9); otherwise the loop
      // assertions could have been skipped entirely.
      Assert.AreEqual(10L, value);
  }
  /// <summary>
  /// Checks that <c>map</c> applies the supplied function to every element:
  /// <c>range(0, 2)</c> mapped with "+ 10" yields 10 and 11.
  /// </summary>
  [TestMethod]
  public void Map()
  {
      long value = 0;

      var dataset = tf.data.Dataset.range(0, 2);
      dataset = dataset.map(x => x[0] + 10);

      foreach (var item in dataset)
      {
          Assert.AreEqual(value + 10, (long)item.Item1);
          value++;
      }

      // Both source elements must have been mapped and visited.
      Assert.AreEqual(2L, value);
  }
  /// <summary>
  /// Checks that <c>cache()</c> leaves the elements of <c>range(5)</c> and their
  /// order unchanged.
  /// </summary>
  [TestMethod]
  public void Cache()
  {
      long value = 0;

      var dataset = tf.data.Dataset.range(5);
      dataset = dataset.cache();

      foreach (var item in dataset)
      {
          Assert.AreEqual(value, (long)item.Item1);
          value++;
      }

      // All five elements must have been produced; an empty iteration would
      // otherwise pass silently.
      Assert.AreEqual(5L, value);
  }
  /// <summary>
  /// Checks that <c>cardinality()</c> reports 10 for <c>range(10)</c>, both
  /// before and after a <c>map</c> is applied.
  /// </summary>
  [TestMethod]
  public void Cardinality()
  {
      var dataset = tf.data.Dataset.range(10);

      // NOTE(review): the arguments are in (actual, expected) order relative to
      // MSTest's (expected, actual) convention. Left as-is because swapping them
      // may change which Equals implementation is used - confirm before changing.
      Assert.AreEqual(dataset.cardinality().numpy(), 10L);

      dataset = dataset.map(x => x[0] + 1);
      Assert.AreEqual(dataset.cardinality().numpy(), 10L);
  }
  /// <summary>
  /// Checks that <c>cardinality()</c> still reports 10 when the identity
  /// <c>map</c> is applied with <c>num_parallel_calls: -1</c>.
  /// </summary>
  [TestMethod]
  public void CardinalityWithAutoTune()
  {
      var dataset = tf.data.Dataset.range(10);
      dataset = dataset.map(x => x, num_parallel_calls: -1);

      // NOTE(review): argument order is (actual, expected); kept unchanged
      // because equality dispatch may depend on it - confirm before swapping.
      Assert.AreEqual(dataset.cardinality().numpy(), 10L);
  }
  /// <summary>
  /// Checks that a repeated dataset reports <c>INFINITE_CARDINALITY</c>, and
  /// that adding a <c>filter</c> on top turns it into <c>UNKNOWN_CARDINALITY</c>.
  /// </summary>
  [TestMethod]
  public void CardinalityWithRepeat()
  {
      var dataset = tf.data.Dataset.range(10);
      dataset = dataset.repeat();

      var isInfinite = dataset.cardinality() == tf.data.INFINITE_CARDINALITY;
      Assert.IsTrue(isInfinite.numpy());

      dataset = dataset.filter(x => true);

      var isUnknown = dataset.cardinality() == tf.data.UNKNOWN_CARDINALITY;
      Assert.IsTrue(isUnknown.numpy());
  }
  /// <summary>
  /// Zips <c>range(3)</c> with a shuffled copy of itself (random seed 1234) and
  /// checks that at least one position differs between the two.
  /// </summary>
  [TestMethod]
  public void Shuffle()
  {
      tf.set_random_seed(1234);

      var ordered = tf.data.Dataset.range(3);
      var reordered = ordered.shuffle(3);
      var pairs = tf.data.Dataset.zip(ordered, reordered);

      bool sameOrder = true;
      foreach (var pair in pairs)
      {
          if (pair.Item1 != pair.Item2)
          {
              sameOrder = false;
          }
      }

      Assert.IsFalse(sameOrder);
  }
  /// <summary>
  /// Ignored smoke test: loads the IMDB dataset through <c>keras.datasets.imdb</c>,
  /// strips zeros from each row with <see cref="RemoveZeros"/>, pads the rows
  /// to <c>maxlen</c> and prints how many training and validation sequences
  /// there are.
  /// </summary>
  [Ignore]
  [TestMethod]
  public void GetData()
  {
      // Only consider the top 20k words.
      var vocab_size = 20000;
      // Only consider the first 200 words of each movie review.
      var maxlen = 200;

      var dataset = keras.datasets.imdb.load_data(num_words: vocab_size, maxlen: maxlen);

      var x_train = dataset.Train.Item1;
      var x_val = dataset.Test.Item1;
      // The labels are read but not used further by this test.
      var y_train = dataset.Train.Item2;
      var y_val = dataset.Test.Item2;

      x_train = keras.preprocessing.sequence.pad_sequences(RemoveZeros(x_train), maxlen: maxlen);
      x_val = keras.preprocessing.sequence.pad_sequences(RemoveZeros(x_val), maxlen: maxlen);

      print($"{len(x_train)} Training sequences");
      print($"{len(x_val)} Validation sequences");
  }
  /// <summary>
  /// Converts a two-dimensional integer array into jagged rows, cutting each
  /// row off at its first zero (the zero and everything after it are dropped).
  /// </summary>
  /// <param name="data">Array that is read as an <c>int[,]</c> via <c>ToMultiDimArray&lt;int&gt;()</c>.</param>
  /// <returns>One <c>int[]</c> per input row, holding the values that precede the first zero.</returns>
  IEnumerable<int[]> RemoveZeros(NDArray data)
  {
      var matrix = (int[,])data.ToMultiDimArray<int>();
      int rowCount = matrix.GetLength(0);
      int columnCount = matrix.GetLength(1);

      var rows = new List<int[]>();
      for (int row = 0; row < rowCount; row++)
      {
          // Find how many leading values come before the first zero.
          int keep = 0;
          while (keep < columnCount && matrix[row, keep] != 0)
          {
              keep++;
          }

          // Copy just that prefix into its own array.
          var trimmed = new int[keep];
          for (int column = 0; column < keep; column++)
          {
              trimmed[column] = matrix[row, column];
          }

          rows.Add(trimmed);
      }

      return rows;
  }
  202. }
  203. }