You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

QueueTest.cs 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System.Linq;
  3. using Tensorflow;
  4. using Tensorflow.UnitTest;
  5. using static Tensorflow.Binding;
  6. namespace TensorFlowNET.UnitTest.Basics
  7. {
  /// <summary>
  /// Graph-mode tests for the queue operations exposed through <c>tf</c>:
  /// <c>PaddingFIFOQueue</c>, <c>FIFOQueue</c>, <c>PriorityQueue</c> and
  /// <c>RandomShuffleQueue</c>. Each test builds the queue ops first and then
  /// executes them inside a session.
  /// </summary>
  [TestClass]
  public class QueueTest : GraphModeTestBase
  {
    /// <summary>
    /// Enqueues three int32 vectors of lengths 1, 2 and 3 and verifies that
    /// dequeuing all three at once yields rows zero-padded to length 3.
    /// </summary>
    [TestMethod]
    public void PaddingFIFOQueue()
    {
      var numbers = tf.placeholder(tf.int32);
      // Capacity 10; the -1 dimension leaves the element length unspecified.
      var queue = tf.PaddingFIFOQueue(10, tf.int32, new TensorShape(-1));
      var enqueue = queue.enqueue(numbers);
      var dequeue_many = queue.dequeue_many(n: 3);

      using (var sess = tf.Session())
      {
        sess.run(enqueue, (numbers, new[] { 1 }));
        sess.run(enqueue, (numbers, new[] { 2, 3 }));
        sess.run(enqueue, (numbers, new[] { 3, 4, 5 }));

        var result = sess.run(dequeue_many[0]);

        // CollectionAssert reports the first mismatching element on failure,
        // unlike Assert.IsTrue over a boolean sequence comparison.
        CollectionAssert.AreEqual(new[] { 1, 0, 0 }, result[0].ToArray<int>());
        CollectionAssert.AreEqual(new[] { 2, 3, 0 }, result[1].ToArray<int>());
        CollectionAssert.AreEqual(new[] { 3, 4, 5 }, result[2].ToArray<int>());
      }
    }

    /// <summary>
    /// Seeds a capacity-2 FIFO queue with 10 and 20, then repeatedly dequeues
    /// the head and enqueues head + 1, checking the dequeued values come out
    /// in first-in-first-out order: 10, 20, 11, 21.
    /// </summary>
    [TestMethod]
    public void FIFOQueue()
    {
      // Create a first-in-first-out queue with capacity 2 and dtype int32.
      var queue = tf.FIFOQueue(2, tf.int32);
      // Initialise the queue by pushing 2 elements (fills it to capacity).
      var init = queue.enqueue_many(new[] { 10, 20 });
      // Pop the first element.
      var x = queue.dequeue();
      // Add 1 to it.
      var y = x + 1;
      // Push the incremented value back into the queue.
      var inc = queue.enqueue(y);

      using (var sess = tf.Session())
      {
        // Fill the queue.
        init.run();

        // Each run pops the head and pushes back the computed y.
        (int dequeued, _) = sess.run((x, inc));
        Assert.AreEqual(10, dequeued);

        (dequeued, _) = sess.run((x, inc));
        Assert.AreEqual(20, dequeued);

        (dequeued, _) = sess.run((x, inc));
        Assert.AreEqual(11, dequeued);

        (dequeued, _) = sess.run((x, inc));
        Assert.AreEqual(21, dequeued);

        // Every run above also re-enqueues one element, so the queue still
        // holds two elements here. Running the dequeue on its own would
        // hang/block the thread once the queue has no more elements.
      }
    }

    /// <summary>
    /// Enqueues three strings with priorities 2, 4 and 3 and verifies that
    /// successive dequeues return the priorities in ascending order (2, 3, 4).
    /// </summary>
    [TestMethod]
    public void PriorityQueue()
    {
      var queue = tf.PriorityQueue(3, tf.@string);
      var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" });
      var x = queue.dequeue();

      using (var sess = tf.Session())
      {
        init.run();

        // MSTest's signature is AreEqual(expected, actual); keeping that order
        // makes failure messages read correctly.
        var result = sess.run(x);
        Assert.AreEqual(2L, result[0].GetInt64());

        result = sess.run(x);
        Assert.AreEqual(3L, result[0].GetInt64());

        result = sess.run(x);
        Assert.AreEqual(4L, result[0].GetInt64());
      }
    }

    /// <summary>
    /// Enqueues 1..10 into a random-shuffle queue, dequeues nine elements and
    /// verifies they did not come out in the original ascending order.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the check is probabilistic — a shuffle could, very
    /// rarely, return 1..9 in order and fail this test.
    /// </remarks>
    [TestMethod]
    public void RandomShuffleQueue()
    {
      var queue = tf.RandomShuffleQueue(10, min_after_dequeue: 1, dtype: tf.int32);
      var init = queue.enqueue_many(new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
      var x = queue.dequeue();
      string results = "";

      using (var sess = tf.Session())
      {
        init.run();

        // Only 9 of the 10 elements are dequeued: min_after_dequeue is 1.
        foreach (var _ in range(9))
        {
          results += (int)sess.run(x) + ".";
        }

        // Output should be in random order; AreNotEqual prints the actual
        // value if the assertion ever fails.
        Assert.AreNotEqual("1.2.3.4.5.6.7.8.9.", results);
      }
    }
  }
  94. }