You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

DatasetTest.cs 2.4 kB

(Line-number gutter: lines 1–88)
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using Tensorflow.UnitTest;
  6. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.Dataset
  {
      /// <summary>
      /// Eager-mode tests for the <c>tf.data.Dataset</c> API: <c>range</c>,
      /// <c>prefetch</c>, <c>from_tensor_slices</c> and <c>shard</c>.
      /// </summary>
      /// <remarks>
      /// Every test asserts how many elements it iterated, not only their values,
      /// so a dataset that yields too few (or zero) elements fails instead of
      /// passing vacuously.
      /// </remarks>
      [TestClass]
      public class DatasetTest : EagerModeTestBase
      {
          /// <summary>
          /// <c>range(3)</c> must yield 0, 1, 2 in order, and <c>enumerate</c>
          /// must report matching step indices.
          /// </summary>
          [TestMethod]
          public void Range()
          {
              int iStep = 0;
              long value = 0;
              var dataset = tf.data.Dataset.range(3);
              foreach (var (step, item) in enumerate(dataset))
              {
                  Assert.AreEqual(iStep, step);
                  iStep++;
                  Assert.AreEqual(value, (long)item.Item1);
                  value++;
              }

              // Without this, an empty dataset would satisfy the loop assertions trivially.
              Assert.AreEqual(3, iStep);
          }

          /// <summary>
          /// Prefetching must not change the contents of <c>range(1, 5, 2)</c>,
          /// which is expected to yield 1 and 3.
          /// </summary>
          [TestMethod]
          public void Prefetch()
          {
              int iStep = 0;
              long value = 1;
              var dataset = tf.data.Dataset.range(1, 5, 2);
              dataset = dataset.prefetch(2);
              foreach (var (step, item) in enumerate(dataset))
              {
                  Assert.AreEqual(iStep, step);
                  iStep++;
                  Assert.AreEqual(value, (long)item.Item1);
                  value += 2;
              }

              // 1 and 3 are the only values in [1, 5) with stride 2.
              Assert.AreEqual(2, iStep);
          }

          /// <summary>
          /// <c>from_tensor_slices(X, Y)</c> must yield one (x, y) pair per row,
          /// pairing the n-th element of X with the n-th element of Y.
          /// </summary>
          [TestMethod]
          public void FromTensorSlices()
          {
              var xs = new[] { 2013, 2014, 2015, 2016, 2017 };
              var ys = new[] { 12000, 14000, 15000, 16500, 17500 };
              var X = tf.constant(xs);
              var Y = tf.constant(ys);
              var dataset = tf.data.Dataset.from_tensor_slices(X, Y);
              int n = 0;
              foreach (var (item_x, item_y) in dataset)
              {
                  // Fail with a clear message rather than an IndexOutOfRangeException.
                  Assert.IsTrue(n < xs.Length, "dataset yielded more elements than were sliced");
                  Assert.AreEqual(xs[n], (int)item_x);
                  Assert.AreEqual(ys[n], (int)item_y);
                  n += 1;
              }

              Assert.AreEqual(xs.Length, n);
          }

          /// <summary>
          /// Sharding <c>range(10)</c> three ways: shard 0 is expected to hold
          /// 0, 3, 6, 9 and shard 1 to hold 1, 4, 7.
          /// </summary>
          [TestMethod]
          public void Shard()
          {
              var dataset1 = tf.data.Dataset.range(10);

              long value = 0;
              int count = 0;
              var dataset2 = dataset1.shard(num_shards: 3, index: 0);
              foreach (var item in dataset2)
              {
                  Assert.AreEqual(value, (long)item.Item1);
                  value += 3;
                  count++;
              }
              // 0, 3, 6, 9
              Assert.AreEqual(4, count);

              value = 1;
              count = 0;
              var dataset3 = dataset1.shard(num_shards: 3, index: 1);
              foreach (var item in dataset3)
              {
                  Assert.AreEqual(value, (long)item.Item1);
                  value += 3;
                  count++;
              }
              // 1, 4, 7
              Assert.AreEqual(3, count);
          }
      }
  }