You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

DatasetTest.cs 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Linq;
  4. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.Dataset
  {
      /// <summary>
      /// Unit tests for the <c>tf.data.Dataset</c> API: construction
      /// (<c>range</c>, <c>from_tensors</c>, <c>from_tensor_slices</c>),
      /// transformations (<c>prefetch</c>, <c>shard</c>, <c>skip</c>, <c>map</c>,
      /// <c>cache</c>, <c>shuffle</c>, <c>zip</c>) and <c>cardinality</c> reporting.
      /// </summary>
      /// <remarks>
      /// Every test that checks element values inside a <c>foreach</c> also checks,
      /// after the loop, that the expected number of elements was produced.
      /// Without that check, a dataset that yields nothing would pass every
      /// in-loop assertion vacuously.
      /// </remarks>
      [TestClass]
      public class DatasetTest : EagerModeTestBase
      {
          /// <summary><c>range(3)</c> yields 0, 1, 2, numbered 0, 1, 2 by <c>enumerate</c>.</summary>
          [TestMethod]
          public void Range()
          {
              int iStep = 0;
              long value = 0;
              var dataset = tf.data.Dataset.range(3);
              foreach (var (step, item) in enumerate(dataset))
              {
                  Assert.AreEqual(iStep, step);
                  iStep++;
                  Assert.AreEqual(value, (long)item.Item1);
                  value++;
              }

              // Guard against an empty or truncated dataset.
              Assert.AreEqual(3, iStep);
          }

          /// <summary><c>range(1, 5, 2).prefetch(2)</c> yields 1, 3 in order.</summary>
          [TestMethod]
          public void Prefetch()
          {
              int iStep = 0;
              long value = 1;
              var dataset = tf.data.Dataset.range(1, 5, 2);
              dataset = dataset.prefetch(2);
              foreach (var (step, item) in enumerate(dataset))
              {
                  Assert.AreEqual(iStep, step);
                  iStep++;
                  Assert.AreEqual(value, (long)item.Item1);
                  value += 2;
              }

              Assert.AreEqual(2, iStep);
          }

          /// <summary>Slicing two 5-element tensors together yields 5 (x, y) pairs.</summary>
          [TestMethod]
          public void FromTensorSlices()
          {
              var X = tf.constant(new[] { 2013, 2014, 2015, 2016, 2017 });
              var Y = tf.constant(new[] { 12000, 14000, 15000, 16500, 17500 });
              var dataset = tf.data.Dataset.from_tensor_slices(X, Y);
              int n = 0;
              foreach (var (item_x, item_y) in dataset)
              {
                  print($"x:{item_x.numpy()},y:{item_y.numpy()}");
                  n += 1;
              }
              Assert.AreEqual(5, n);
          }

          /// <summary><c>from_tensors</c> yields the whole array as a single element.</summary>
          [TestMethod]
          public void FromTensor()
          {
              var X = new[] { 2013, 2014, 2015, 2016, 2017 };
              var dataset = tf.data.Dataset.from_tensors(X);
              int n = 0;
              foreach (var x in dataset)
              {
                  Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray<int>()));
                  n += 1;
              }
              Assert.AreEqual(1, n);
          }

          /// <summary>
          /// Sharding <c>range(10)</c> into 3 shards: shard 0 yields 0, 3, 6, 9 and
          /// shard 1 yields 1, 4, 7.
          /// </summary>
          [TestMethod]
          public void Shard()
          {
              var dataset1 = tf.data.Dataset.range(10);

              long value = 0;
              int count = 0;
              var dataset2 = dataset1.shard(num_shards: 3, index: 0);
              foreach (var item in dataset2)
              {
                  Assert.AreEqual(value, (long)item.Item1);
                  value += 3;
                  count++;
              }
              Assert.AreEqual(4, count);

              value = 1;
              count = 0;
              var dataset3 = dataset1.shard(num_shards: 3, index: 1);
              foreach (var item in dataset3)
              {
                  Assert.AreEqual(value, (long)item.Item1);
                  value += 3;
                  count++;
              }
              Assert.AreEqual(3, count);
          }

          /// <summary><c>range(10).skip(7)</c> yields 7, 8, 9.</summary>
          [TestMethod]
          public void Skip()
          {
              long value = 7;
              var dataset = tf.data.Dataset.range(10);
              dataset = dataset.skip(7);
              foreach (var item in dataset)
              {
                  Assert.AreEqual(value, (long)item.Item1);
                  value++;
              }

              // value advanced past 9 only if 7, 8 and 9 were all produced.
              Assert.AreEqual(10L, value);
          }

          /// <summary>Mapping <c>x + 10</c> over <c>range(0, 2)</c> yields 10, 11.</summary>
          [TestMethod]
          public void Map()
          {
              long value = 0;
              var dataset = tf.data.Dataset.range(0, 2);
              dataset = dataset.map(x => x[0] + 10);
              foreach (var item in dataset)
              {
                  Assert.AreEqual(value + 10, (long)item.Item1);
                  value++;
              }

              Assert.AreEqual(2L, value);
          }

          /// <summary><c>range(5).cache()</c> yields 0..4 unchanged.</summary>
          [TestMethod]
          public void Cache()
          {
              long value = 0;
              var dataset = tf.data.Dataset.range(5);
              dataset = dataset.cache();
              foreach (var item in dataset)
              {
                  Assert.AreEqual(value, (long)item.Item1);
                  value++;
              }

              Assert.AreEqual(5L, value);
          }

          /// <summary><c>range(10)</c> reports cardinality 10, and <c>map</c> preserves it.</summary>
          [TestMethod]
          public void Cardinality()
          {
              var dataset = tf.data.Dataset.range(10);
              var cardinality = dataset.cardinality();
              // MSTest convention: expected value first, actual second.
              Assert.AreEqual(10L, (long)cardinality);
              dataset = dataset.map(x => x[0] + 1);
              cardinality = dataset.cardinality();
              Assert.AreEqual(10L, (long)cardinality);
          }

          /// <summary>A map with <c>num_parallel_calls: -1</c> still reports cardinality 10.</summary>
          [TestMethod]
          public void CardinalityWithAutoTune()
          {
              var dataset = tf.data.Dataset.range(10);
              dataset = dataset.map(x => x, num_parallel_calls: -1);
              var cardinality = dataset.cardinality();
              Assert.AreEqual(10L, (long)cardinality);
          }

          /// <summary>
          /// <c>repeat()</c> with no count reports infinite cardinality. A subsequent
          /// <c>filter</c> makes the cardinality unknown.
          /// </summary>
          [TestMethod]
          public void CardinalityWithRepeat()
          {
              var dataset = tf.data.Dataset.range(10);
              dataset = dataset.repeat();
              var cardinality = dataset.cardinality();
              Assert.IsTrue((cardinality == tf.data.INFINITE_CARDINALITY).numpy());
              dataset = dataset.filter(x => true);
              cardinality = dataset.cardinality();
              Assert.IsTrue((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy());
          }

          /// <summary>
          /// With a fixed seed, <c>shuffle(3)</c> over <c>range(3)</c> produces a
          /// permutation of 0, 1, 2 whose order differs from the input.
          /// </summary>
          [TestMethod]
          public void Shuffle()
          {
              tf.set_random_seed(1234);
              var dataset = tf.data.Dataset.range(3);
              var shuffled = dataset.shuffle(3);
              var zipped = tf.data.Dataset.zip(dataset, shuffled);

              bool allEqual = true;
              var seen = new bool[3];
              foreach (var item in zipped)
              {
                  // Compare scalar values explicitly instead of relying on the
                  // Tensor != operator, whose result type/semantics are not
                  // obvious at this call site.
                  long original = (long)item.Item1;
                  long reordered = (long)item.Item2;
                  if (original != reordered)
                  {
                      allEqual = false;
                  }
                  seen[reordered] = true;
              }

              // Every input value must appear exactly in the shuffled stream.
              Assert.IsTrue(seen.All(s => s));
              Assert.IsFalse(allEqual);
          }
      }
  }