You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

DatasetTest.cs 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using static Tensorflow.Binding;
  6. using static Tensorflow.KerasApi;
  7. namespace TensorFlowNET.UnitTest.Dataset
  8. {
  /// <summary>
  /// Tests for the <c>tf.data.Dataset</c> operations exposed through the
  /// TensorFlow binding: creation, transformation, cardinality and iteration.
  /// </summary>
  [TestClass]
  public class DatasetTest : EagerModeTestBase
  {
      /// <summary>
      /// <c>range(3)</c> must yield 0, 1, 2 and <c>enumerate</c> must number them 0, 1, 2.
      /// </summary>
      [TestMethod]
      public void Range()
      {
          int expectedStep = 0;
          long expectedValue = 0;
          var dataset = tf.data.Dataset.range(3);

          foreach (var (step, item) in enumerate(dataset))
          {
              Assert.AreEqual(expectedStep, step);
              expectedStep++;

              Assert.AreEqual(expectedValue, (long)item.Item1);
              expectedValue++;
          }

          // Guard against a vacuous pass: the loop body must have run once per element.
          Assert.AreEqual(3, expectedStep);
      }

      /// <summary>
      /// <c>range(1, 5, 2)</c> followed by <c>prefetch(2)</c> must still yield 1, 3 in order.
      /// </summary>
      [TestMethod]
      public void Prefetch()
      {
          int expectedStep = 0;
          long expectedValue = 1;
          var dataset = tf.data.Dataset.range(1, 5, 2);
          dataset = dataset.prefetch(2);

          foreach (var (step, item) in enumerate(dataset))
          {
              Assert.AreEqual(expectedStep, step);
              expectedStep++;

              Assert.AreEqual(expectedValue, (long)item.Item1);
              expectedValue += 2;
          }

          // Guard against a vacuous pass: 1 and 3 are the only values below 5.
          Assert.AreEqual(2, expectedStep);
      }

      /// <summary>
      /// Slicing two five-element tensors must produce five (x, y) pairs.
      /// </summary>
      [TestMethod]
      public void FromTensorSlices()
      {
          var X = tf.constant(new[] { 2013, 2014, 2015, 2016, 2017 });
          var Y = tf.constant(new[] { 12000, 14000, 15000, 16500, 17500 });
          var dataset = tf.data.Dataset.from_tensor_slices(X, Y);

          int pairCount = 0;
          foreach (var (item_x, item_y) in dataset)
          {
              print($"x:{item_x.numpy()},y:{item_y.numpy()}");
              pairCount++;
          }

          Assert.AreEqual(5, pairCount);
      }

      /// <summary>
      /// <c>from_tensors</c> must wrap the whole array as a single element whose
      /// contents equal the source array.
      /// </summary>
      [TestMethod]
      public void FromTensor()
      {
          var X = new[] { 2013, 2014, 2015, 2016, 2017 };
          var dataset = tf.data.Dataset.from_tensors(X);

          int elementCount = 0;
          foreach (var x in dataset)
          {
              Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray<int>()));
              elementCount++;
          }

          Assert.AreEqual(1, elementCount);
      }

      /// <summary>
      /// Sharding <c>range(10)</c> three ways must give 0, 3, 6, 9 for index 0
      /// and 1, 4, 7 for index 1.
      /// </summary>
      [TestMethod]
      public void Shard()
      {
          var dataset1 = tf.data.Dataset.range(10);

          long expectedValue = 0;
          int shard0Count = 0;
          var dataset2 = dataset1.shard(num_shards: 3, index: 0);
          foreach (var item in dataset2)
          {
              Assert.AreEqual(expectedValue, (long)item.Item1);
              expectedValue += 3;
              shard0Count++;
          }

          // Guard against a vacuous pass: shard 0 holds 0, 3, 6, 9.
          Assert.AreEqual(4, shard0Count);

          expectedValue = 1;
          int shard1Count = 0;
          var dataset3 = dataset1.shard(num_shards: 3, index: 1);
          foreach (var item in dataset3)
          {
              Assert.AreEqual(expectedValue, (long)item.Item1);
              expectedValue += 3;
              shard1Count++;
          }

          // Guard against a vacuous pass: shard 1 holds 1, 4, 7.
          Assert.AreEqual(3, shard1Count);
      }

      /// <summary>
      /// Skipping 7 elements of <c>range(10)</c> must leave 7, 8, 9.
      /// </summary>
      [TestMethod]
      public void Skip()
      {
          long expectedValue = 7;
          var dataset = tf.data.Dataset.range(10);
          dataset = dataset.skip(7);

          foreach (var item in dataset)
          {
              Assert.AreEqual(expectedValue, (long)item.Item1);
              expectedValue++;
          }

          // Guard against a vacuous pass: the counter ends one past the last element (9).
          Assert.AreEqual(10L, expectedValue);
      }

      /// <summary>
      /// Mapping <c>x + 10</c> over <c>range(0, 2)</c> must yield 10, 11.
      /// </summary>
      [TestMethod]
      public void Map()
      {
          long sourceValue = 0;
          var dataset = tf.data.Dataset.range(0, 2);
          dataset = dataset.map(x => x[0] + 10);

          foreach (var item in dataset)
          {
              Assert.AreEqual(sourceValue + 10, (long)item.Item1);
              sourceValue++;
          }

          // Guard against a vacuous pass: both source elements must have been mapped.
          Assert.AreEqual(2L, sourceValue);
      }

      /// <summary>
      /// Caching <c>range(5)</c> must not alter the values or their order.
      /// </summary>
      [TestMethod]
      public void Cache()
      {
          long expectedValue = 0;
          var dataset = tf.data.Dataset.range(5);
          dataset = dataset.cache();

          foreach (var item in dataset)
          {
              Assert.AreEqual(expectedValue, (long)item.Item1);
              expectedValue++;
          }

          // Guard against a vacuous pass: all five cached elements must be produced.
          Assert.AreEqual(5L, expectedValue);
      }

      /// <summary>
      /// <c>cardinality()</c> must report 10 for <c>range(10)</c>, both before and after a map.
      /// </summary>
      [TestMethod]
      public void Cardinality()
      {
          var dataset = tf.data.Dataset.range(10);
          var cardinality = dataset.cardinality();

          // NOTE(review): arguments are in (actual, expected) order relative to MSTest's
          // signature. Presumably the comparison relies on the Equals of the numpy() result;
          // confirm before swapping them.
          Assert.AreEqual(cardinality.numpy(), 10L);

          dataset = dataset.map(x => x[0] + 1);
          cardinality = dataset.cardinality();
          Assert.AreEqual(cardinality.numpy(), 10L);
      }

      /// <summary>
      /// <c>cardinality()</c> must still report 10 after an identity map with
      /// <c>num_parallel_calls: -1</c>.
      /// </summary>
      [TestMethod]
      public void CardinalityWithAutoTune()
      {
          var dataset = tf.data.Dataset.range(10);
          dataset = dataset.map(x => x, num_parallel_calls: -1);
          var cardinality = dataset.cardinality();

          // NOTE(review): same (actual, expected) ordering as in Cardinality; confirm before swapping.
          Assert.AreEqual(cardinality.numpy(), 10L);
      }

      /// <summary>
      /// A repeated dataset must report <c>INFINITE_CARDINALITY</c>; adding a filter
      /// must change that to <c>UNKNOWN_CARDINALITY</c>.
      /// </summary>
      [TestMethod]
      public void CardinalityWithRepeat()
      {
          var dataset = tf.data.Dataset.range(10);
          dataset = dataset.repeat();
          var cardinality = dataset.cardinality();
          Assert.IsTrue((cardinality == tf.data.INFINITE_CARDINALITY).numpy());

          dataset = dataset.filter(x => true);
          cardinality = dataset.cardinality();
          Assert.IsTrue((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy());
      }

      /// <summary>
      /// With the random seed set to 1234, shuffling <c>range(3)</c> with a buffer of 3
      /// must differ from the source in at least one position.
      /// </summary>
      [TestMethod]
      public void Shuffle()
      {
          tf.set_random_seed(1234);

          var dataset = tf.data.Dataset.range(3);
          var shuffled = dataset.shuffle(3);
          var zipped = tf.data.Dataset.zip(dataset, shuffled);

          bool allEqual = true;
          foreach (var item in zipped)
          {
              // NOTE(review): `!=` resolves to whatever operator the element type defines;
              // confirm it compares values rather than references.
              if (item.Item1 != item.Item2)
              {
                  allEqual = false;
              }
          }

          Assert.IsFalse(allEqual);
      }

      /// <summary>
      /// Loads the IMDB dataset and pads the training and validation sequences to
      /// <c>maxlen</c>. Contains no assertions, so it only checks that loading and
      /// padding complete without throwing.
      /// </summary>
      [TestMethod]
      public void GetData()
      {
          var vocab_size = 20000; // Only consider the top 20k words
          var maxlen = 200; // Only consider the first 200 words of each movie review

          var dataset = keras.datasets.imdb.load_data(num_words: vocab_size);
          var x_train = dataset.Train.Item1;
          var y_train = dataset.Train.Item2;
          var x_val = dataset.Test.Item1;
          var y_val = dataset.Test.Item2;

          // A space separates the count from the label; plain concatenation fused them.
          print($"{len(x_train)} Training sequences");
          print($"{len(x_val)} Validation sequences");

          x_train = keras.preprocessing.sequence.pad_sequences((IEnumerable<int[]>)x_train, maxlen: maxlen);
          x_val = keras.preprocessing.sequence.pad_sequences((IEnumerable<int[]>)x_val, maxlen: maxlen);
      }
  }
  183. }