You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

QueueTest.cs 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System.Linq;
  3. using Tensorflow;
  4. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.Basics
  {
      /// <summary>
      /// Tests for the TensorFlow.NET queue ops created through <c>tf</c>:
      /// <c>PaddingFIFOQueue</c>, <c>FIFOQueue</c>, <c>PriorityQueue</c> and <c>RandomShuffleQueue</c>.
      /// Each test builds the queue ops first, then runs them in its own session,
      /// which is disposed when the test finishes.
      /// </summary>
      [TestClass]
      public class QueueTest : GraphModeTestBase
      {
          /// <summary>
          /// Enqueues int32 vectors of lengths 1, 2 and 3 into a <c>PaddingFIFOQueue</c> and checks that
          /// <c>dequeue_many(n: 3)</c> returns them in FIFO order, zero-padded to length 3.
          /// </summary>
          [TestMethod]
          public void PaddingFIFOQueue()
          {
              var numbers = tf.placeholder(tf.int32);
              // Capacity 10, element shape Shape(-1): a 1-D element whose length is not fixed.
              var queue = tf.PaddingFIFOQueue(10, tf.int32, new Shape(-1));
              var enqueue = queue.enqueue(numbers);
              var dequeue_many = queue.dequeue_many(n: 3);

              // Dispose the session so its native resources are released when the test ends.
              using (var sess = tf.Session())
              {
                  sess.run(enqueue, (numbers, new[] { 1 }));
                  sess.run(enqueue, (numbers, new[] { 2, 3 }));
                  sess.run(enqueue, (numbers, new[] { 3, 4, 5 }));

                  var result = sess.run(dequeue_many[0]);

                  // CollectionAssert reports which element differs on failure,
                  // unlike Assert.IsTrue(SequenceEqual(...)).
                  CollectionAssert.AreEqual(new int[] { 1, 0, 0 }, result[0].ToArray<int>());
                  CollectionAssert.AreEqual(new int[] { 2, 3, 0 }, result[1].ToArray<int>());
                  CollectionAssert.AreEqual(new int[] { 3, 4, 5 }, result[2].ToArray<int>());
              }
          }

          /// <summary>
          /// Builds a step that pops the front of a capacity-2 <c>FIFOQueue</c> and pushes back
          /// that value + 1, then checks the values popped over four steps.
          /// </summary>
          [TestMethod]
          public void FIFOQueue()
          {
              // First-in-first-out queue holding at most 2 int32 elements.
              var queue = tf.FIFOQueue(2, tf.int32);
              // Initializer: push 2 elements (10, 20) into the queue.
              var init = queue.enqueue_many(new[] { 10, 20 });
              // Pop the front element...
              var x = queue.dequeue();
              // ...add 1 to it...
              var y = x + 1;
              // ...and push the result onto the back of the queue.
              var inc = queue.enqueue(y);

              using (var sess = tf.Session())
              {
                  // init.run() is not handed a session explicitly; it presumably runs
                  // on the default session created above.
                  init.run();

                  // Each step pops the front element and pushes back (front + 1):
                  // [10, 20] -> [20, 11] -> [11, 21] -> [21, 12] -> [12, 22]
                  (int dequeued, _) = sess.run((x, inc));
                  Assert.AreEqual(10, dequeued);
                  (dequeued, _) = sess.run((x, inc));
                  Assert.AreEqual(20, dequeued);
                  (dequeued, _) = sess.run((x, inc));
                  Assert.AreEqual(11, dequeued);
                  (dequeued, _) = sess.run((x, inc));
                  Assert.AreEqual(21, dequeued);

                  // The queue still holds 12 and 22 here. A dequeue only blocks once the queue is
                  // empty, i.e. after two more plain sess.run(x) calls with nothing enqueued.
              }
          }

          /// <summary>
          /// Enqueues three string values with int64 priorities 2, 4 and 3 into a <c>PriorityQueue</c>
          /// and checks that they are dequeued in ascending priority order.
          /// </summary>
          [TestMethod]
          public void PriorityQueue()
          {
              var queue = tf.PriorityQueue(3, tf.@string);
              var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" });
              var x = queue.dequeue();

              using (var sess = tf.Session())
              {
                  init.run();

                  foreach (var expectedPriority in new[] { 2L, 3L, 4L })
                  {
                      var result = sess.run(x);
                      // MSTest's AreEqual takes (expected, actual). The scalar is extracted explicitly
                      // rather than relying on NDArray.Equals(object), so the failure message is meaningful.
                      Assert.AreEqual(expectedPriority, result[0].ToArray<long>()[0]);
                  }
              }
          }

          /// <summary>
          /// Enqueues 1..10 into a <c>RandomShuffleQueue</c>, dequeues 9 values and checks that they
          /// are distinct enqueued values that do not come out in plain insertion order.
          /// </summary>
          [TestMethod]
          public void RandomShuffleQueue()
          {
              var queue = tf.RandomShuffleQueue(10, min_after_dequeue: 1, dtype: tf.int32);
              var init = queue.enqueue_many(new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
              var x = queue.dequeue();

              // Only 9 of the 10 elements are taken. Presumably min_after_dequeue: 1 holds the last one
              // back, so a 10th dequeue could block. TODO confirm.
              var dequeued = new int[9];
              using (var sess = tf.Session())
              {
                  init.run();
                  for (int i = 0; i < dequeued.Length; i++)
                  {
                      dequeued[i] = (int)sess.run(x);
                  }
              }

              // Every dequeued value must be one of the enqueued values, with no duplicates.
              CollectionAssert.AllItemsAreUnique(dequeued);
              Assert.IsTrue(dequeued.All(v => v >= 1 && v <= 10));

              // The output should be shuffled. A random order can match insertion order by chance
              // (with small probability), so this check is probabilistic.
              Assert.IsFalse(dequeued.SequenceEqual(Enumerable.Range(1, 9)));
          }
      }
  }