You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

DatasetTest.cs 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using Tensorflow;
  7. using Tensorflow.Keras;
  8. using Tensorflow.UnitTest;
  9. using static Tensorflow.Binding;
  10. namespace TensorFlowNET.UnitTest.Dataset
  11. {
  12. [TestClass]
  13. public class DatasetTest : EagerModeTestBase
  14. {
  15. [TestMethod]
  16. public void Range()
  17. {
  18. int iStep = 0;
  19. long value = 0;
  20. var dataset = tf.data.Dataset.range(3);
  21. foreach(var (step, item) in enumerate(dataset))
  22. {
  23. Assert.AreEqual(iStep, step);
  24. iStep++;
  25. Assert.AreEqual(value, (long)item.Item1);
  26. value++;
  27. }
  28. }
  29. [TestMethod]
  30. public void Prefetch()
  31. {
  32. int iStep = 0;
  33. long value = 1;
  34. var dataset = tf.data.Dataset.range(1, 5, 2);
  35. dataset = dataset.prefetch(2);
  36. foreach (var (step, item) in enumerate(dataset))
  37. {
  38. Assert.AreEqual(iStep, step);
  39. iStep++;
  40. Assert.AreEqual(value, (long)item.Item1);
  41. value += 2;
  42. }
  43. }
  44. [TestMethod]
  45. public void FromTensorSlices()
  46. {
  47. var X = tf.constant(new[] { 2013, 2014, 2015, 2016, 2017 });
  48. var Y = tf.constant(new[] { 12000, 14000, 15000, 16500, 17500 });
  49. var dataset = tf.data.Dataset.from_tensor_slices(X, Y);
  50. int n = 0;
  51. foreach (var (item_x, item_y) in dataset)
  52. {
  53. print($"x:{item_x.numpy()},y:{item_y.numpy()}");
  54. n += 1;
  55. }
  56. Assert.AreEqual(5, n);
  57. }
  58. [TestMethod]
  59. public void FromTensor()
  60. {
  61. var X = new[] { 2013, 2014, 2015, 2016, 2017 };
  62. var dataset = tf.data.Dataset.from_tensor(X);
  63. int n = 0;
  64. foreach (var x in dataset)
  65. {
  66. Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray<int>()));
  67. n += 1;
  68. }
  69. Assert.AreEqual(1, n);
  70. }
  71. [TestMethod]
  72. public void Shard()
  73. {
  74. long value = 0;
  75. var dataset1 = tf.data.Dataset.range(10);
  76. var dataset2 = dataset1.shard(num_shards: 3, index: 0);
  77. foreach (var item in dataset2)
  78. {
  79. Assert.AreEqual(value, (long)item.Item1);
  80. value += 3;
  81. }
  82. value = 1;
  83. var dataset3 = dataset1.shard(num_shards: 3, index: 1);
  84. foreach (var item in dataset3)
  85. {
  86. Assert.AreEqual(value, (long)item.Item1);
  87. value += 3;
  88. }
  89. }
  90. [TestMethod]
  91. public void Skip()
  92. {
  93. long value = 7;
  94. var dataset = tf.data.Dataset.range(10);
  95. dataset = dataset.skip(7);
  96. foreach (var item in dataset)
  97. {
  98. Assert.AreEqual(value, (long)item.Item1);
  99. value ++;
  100. }
  101. }
  102. [TestMethod]
  103. public void Map()
  104. {
  105. long value = 0;
  106. var dataset = tf.data.Dataset.range(0, 2);
  107. dataset = dataset.map(x => x + 10);
  108. foreach (var item in dataset)
  109. {
  110. Assert.AreEqual(value + 10, (long)item.Item1);
  111. value++;
  112. }
  113. }
  114. [TestMethod]
  115. public void Cache()
  116. {
  117. long value = 0;
  118. var dataset = tf.data.Dataset.range(5);
  119. dataset = dataset.cache();
  120. foreach (var item in dataset)
  121. {
  122. Assert.AreEqual(value, (long)item.Item1);
  123. value++;
  124. }
  125. }
  126. }
  127. }