You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

DatasetTest.cs 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using Tensorflow.UnitTest;
  6. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.Dataset
  {
      /// <summary>
      /// Eager-mode tests for <c>tf.data.Dataset</c>: creation via <c>range</c> and
      /// <c>from_tensor_slices</c>, plus the <c>prefetch</c>, <c>shard</c> and <c>skip</c>
      /// transformations.
      /// </summary>
      /// <remarks>
      /// Each test asserts how many elements it consumed once the loop finishes. The
      /// per-element assertions live inside <c>foreach</c> bodies, so without the count
      /// check an empty dataset would make the test pass while verifying nothing.
      /// </remarks>
      [TestClass]
      public class DatasetTest : EagerModeTestBase
      {
          /// <summary>
          /// <c>range(3)</c> yields 0, 1, 2 in order, and <c>enumerate</c> pairs each
          /// element with its zero-based position.
          /// </summary>
          [TestMethod]
          public void Range()
          {
              var dataset = tf.data.Dataset.range(3);

              int iStep = 0;
              long value = 0;
              foreach (var (step, item) in enumerate(dataset))
              {
                  Assert.AreEqual(iStep, step);
                  Assert.AreEqual(value, (long)item.Item1);
                  iStep++;
                  value++;
              }

              Assert.AreEqual(3, iStep);
          }

          /// <summary>
          /// Prefetching must leave the stream unchanged: <c>range(1, 5, 2)</c> still
          /// yields exactly 1 and 3, in order.
          /// </summary>
          [TestMethod]
          public void Prefetch()
          {
              var dataset = tf.data.Dataset.range(1, 5, 2);
              dataset = dataset.prefetch(2);

              int iStep = 0;
              long value = 1;
              foreach (var (step, item) in enumerate(dataset))
              {
                  Assert.AreEqual(iStep, step);
                  Assert.AreEqual(value, (long)item.Item1);
                  iStep++;
                  value += 2;
              }

              Assert.AreEqual(2, iStep);
          }

          /// <summary>
          /// <c>from_tensor_slices(X, Y)</c> yields one <c>(x, y)</c> pair per row, with
          /// <c>X[i]</c> paired with <c>Y[i]</c>.
          /// </summary>
          [TestMethod]
          public void FromTensorSlices()
          {
              int[] years = { 2013, 2014, 2015, 2016, 2017 };
              int[] prices = { 12000, 14000, 15000, 16500, 17500 };
              var X = tf.constant(years);
              var Y = tf.constant(prices);
              var dataset = tf.data.Dataset.from_tensor_slices(X, Y);

              int n = 0;
              foreach (var (itemX, itemY) in dataset)
              {
                  print($"x:{itemX.numpy()},y:{itemY.numpy()}");
                  // Both source tensors are int32, so each slice is an int32 scalar.
                  Assert.AreEqual(years[n], (int)itemX);
                  Assert.AreEqual(prices[n], (int)itemY);
                  n++;
              }

              Assert.AreEqual(years.Length, n);
          }

          /// <summary>
          /// <c>shard(num_shards: 3, index: k)</c> over <c>range(10)</c> keeps only the
          /// elements at positions congruent to <c>k</c> modulo 3.
          /// </summary>
          [TestMethod]
          public void Shard()
          {
              var dataset1 = tf.data.Dataset.range(10);

              // Shard 0 of 3 is expected to yield 0, 3, 6, 9.
              var dataset2 = dataset1.shard(num_shards: 3, index: 0);
              long value = 0;
              int count = 0;
              foreach (var item in dataset2)
              {
                  Assert.AreEqual(value, (long)item.Item1);
                  value += 3;
                  count++;
              }
              Assert.AreEqual(4, count);

              // Shard 1 of 3 is expected to yield 1, 4, 7.
              value = 1;
              count = 0;
              var dataset3 = dataset1.shard(num_shards: 3, index: 1);
              foreach (var item in dataset3)
              {
                  Assert.AreEqual(value, (long)item.Item1);
                  value += 3;
                  count++;
              }
              Assert.AreEqual(3, count);
          }

          /// <summary>
          /// <c>skip(7)</c> over <c>range(10)</c> drops the first seven elements and
          /// yields 7, 8, 9.
          /// </summary>
          [TestMethod]
          public void Skip()
          {
              var dataset = tf.data.Dataset.range(10);
              dataset = dataset.skip(7);

              long value = 7;
              int count = 0;
              foreach (var item in dataset)
              {
                  Assert.AreEqual(value, (long)item.Item1);
                  value++;
                  count++;
              }

              Assert.AreEqual(3, count);
          }
      }
  }