You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

QueueTest.cs 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System.Linq;
  3. using Tensorflow;
  4. using static Tensorflow.Binding;
  5. namespace TensorFlowNET.UnitTest.Basics
  6. {
  7. [TestClass]
  8. public class QueueTest : GraphModeTestBase
  9. {
  10. [TestMethod]
  11. public void PaddingFIFOQueue()
  12. {
  13. var numbers = tf.placeholder(tf.int32);
  14. var queue = tf.PaddingFIFOQueue(10, tf.int32, new Shape(-1));
  15. var enqueue = queue.enqueue(numbers);
  16. var dequeue_many = queue.dequeue_many(n: 3);
  17. using (var sess = tf.Session())
  18. {
  19. sess.run(enqueue, (numbers, new[] { 1 }));
  20. sess.run(enqueue, (numbers, new[] { 2, 3 }));
  21. sess.run(enqueue, (numbers, new[] { 3, 4, 5 }));
  22. var result = sess.run(dequeue_many[0]);
  23. Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray<int>()));
  24. Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray<int>()));
  25. Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray<int>()));
  26. }
  27. }
  28. [TestMethod]
  29. public void FIFOQueue()
  30. {
  31. // create a first in first out queue with capacity up to 2
  32. // and data type set as int32
  33. var queue = tf.FIFOQueue(2, tf.int32);
  34. // init queue, push 3 elements into queue.
  35. var init = queue.enqueue_many(new[] { 10, 20 });
  36. // pop out the first element
  37. var x = queue.dequeue();
  38. // add 1
  39. var y = x + 1;
  40. // push back into queue
  41. var inc = queue.enqueue(y);
  42. using (var sess = tf.Session())
  43. {
  44. // init queue
  45. init.run();
  46. // pop out first element and push back calculated y
  47. (int dequeued, _) = sess.run((x, inc));
  48. Assert.AreEqual(10, dequeued);
  49. (dequeued, _) = sess.run((x, inc));
  50. Assert.AreEqual(20, dequeued);
  51. (dequeued, _) = sess.run((x, inc));
  52. Assert.AreEqual(11, dequeued);
  53. (dequeued, _) = sess.run((x, inc));
  54. Assert.AreEqual(21, dequeued);
  55. // thread will hang or block if you run sess.run(x) again
  56. // until queue has more element.
  57. }
  58. }
  59. [TestMethod]
  60. public void PriorityQueue()
  61. {
  62. var queue = tf.PriorityQueue(3, tf.@string);
  63. var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" });
  64. var x = queue.dequeue();
  65. using (var sess = tf.Session())
  66. {
  67. init.run();
  68. var result = sess.run(x);
  69. Assert.AreEqual(result[0], 2L);
  70. result = sess.run(x);
  71. Assert.AreEqual(result[0], 3L);
  72. result = sess.run(x);
  73. Assert.AreEqual(result[0], 4L);
  74. }
  75. }
  76. [TestMethod]
  77. public void RandomShuffleQueue()
  78. {
  79. var queue = tf.RandomShuffleQueue(10, min_after_dequeue: 1, dtype: tf.int32);
  80. var init = queue.enqueue_many(new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
  81. var x = queue.dequeue();
  82. string results = "";
  83. using (var sess = tf.Session())
  84. {
  85. init.run();
  86. foreach (var i in range(9))
  87. results += (int)sess.run(x) + ".";
  88. // output in random order
  89. Assert.IsFalse(results == "1.2.3.4.5.6.7.8.9.");
  90. }
  91. }
  92. }
  93. }