You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

QueueTest.cs 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using Tensorflow;
  7. using Tensorflow.UnitTest;
  8. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.Basics
  {
      /// <summary>
      /// Tests for the queue ops exposed through <c>tf</c>: PaddingFIFOQueue, FIFOQueue,
      /// PriorityQueue and RandomShuffleQueue. Each test builds the queue ops first and
      /// then executes them inside a <c>tf.Session()</c>.
      /// </summary>
      [TestClass]
      public class QueueTest : GraphModeTestBase
      {
          /// <summary>
          /// Enqueues three int32 vectors of lengths 1, 2 and 3 and dequeues them as one
          /// batch of 3; shorter rows are expected to be zero-padded to length 3.
          /// </summary>
          [TestMethod]
          public void PaddingFIFOQueue()
          {
              var numbers = tf.placeholder(tf.int32);
              // Capacity 10, element shape [-1] (a vector of any length).
              var queue = tf.PaddingFIFOQueue(10, tf.int32, new TensorShape(-1));
              var enqueue = queue.enqueue(numbers);
              var dequeue_many = queue.dequeue_many(n: 3);

              using (var sess = tf.Session())
              {
                  sess.run(enqueue, (numbers, new[] { 1 }));
                  sess.run(enqueue, (numbers, new[] { 2, 3 }));
                  sess.run(enqueue, (numbers, new[] { 3, 4, 5 }));

                  var result = sess.run(dequeue_many[0]);

                  // CollectionAssert reports the mismatching element on failure, whereas
                  // Assert.IsTrue(SequenceEqual(...)) only reports "false".
                  CollectionAssert.AreEqual(new[] { 1, 0, 0 }, result[0].ToArray<int>());
                  CollectionAssert.AreEqual(new[] { 2, 3, 0 }, result[1].ToArray<int>());
                  CollectionAssert.AreEqual(new[] { 3, 4, 5 }, result[2].ToArray<int>());
              }
          }

          /// <summary>
          /// Round-trips elements through a capacity-2 FIFO queue: each step dequeues the
          /// head and enqueues head + 1, so values come back in FIFO order as
          /// 10, 20, 11, 21.
          /// </summary>
          [TestMethod]
          public void FIFOQueue()
          {
              // First-in-first-out queue of int32 with capacity 2.
              var queue = tf.FIFOQueue(2, tf.int32);
              // Initializer op: push the 2 elements 10 and 20.
              var init = queue.enqueue_many(new[] { 10, 20 });
              // Pop the head element.
              var x = queue.dequeue();
              var y = x + 1;
              // Push the incremented value back onto the tail.
              var inc = queue.enqueue(y);

              using (var sess = tf.Session())
              {
                  init.run();

                  // Each run evaluates x once (one dequeue) and enqueues x + 1, so the
                  // queue holds 2 elements after every step.
                  (int dequeued, _) = sess.run((x, inc));
                  Assert.AreEqual(10, dequeued);
                  (dequeued, _) = sess.run((x, inc));
                  Assert.AreEqual(20, dequeued);
                  (dequeued, _) = sess.run((x, inc));
                  Assert.AreEqual(11, dequeued);
                  (dequeued, _) = sess.run((x, inc));
                  Assert.AreEqual(21, dequeued);

                  // The queue now holds 12 and 22. Running x alone would drain them, and a
                  // further dequeue would block until something else is enqueued.
              }
          }

          /// <summary>
          /// Enqueues three string values with int64 priorities 2, 4, 3 and checks that
          /// dequeues come out in ascending priority order: 2, 3, 4.
          /// </summary>
          [TestMethod]
          public void PriorityQueue()
          {
              var queue = tf.PriorityQueue(3, tf.@string);
              var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" });
              var x = queue.dequeue();

              using (var sess = tf.Session())
              {
                  init.run();

                  // result[0] holds the priority of the dequeued element.
                  // MSTest's signature is Assert.AreEqual(expected, actual).
                  var result = sess.run(x);
                  Assert.AreEqual(2L, result[0].GetInt64());
                  result = sess.run(x);
                  Assert.AreEqual(3L, result[0].GetInt64());
                  result = sess.run(x);
                  Assert.AreEqual(4L, result[0].GetInt64());
              }
          }

          /// <summary>
          /// Enqueues 1..10 into a random-shuffle queue, dequeues 9 of them, and checks that
          /// the values are distinct members of the input and are not in insertion order.
          /// </summary>
          [TestMethod]
          public void RandomShuffleQueue()
          {
              var queue = tf.RandomShuffleQueue(10, min_after_dequeue: 1, dtype: tf.int32);
              var init = queue.enqueue_many(new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
              var x = queue.dequeue();
              var dequeued = new List<int>();

              using (var sess = tf.Session())
              {
                  init.run();

                  // Only 9 of the 10 elements are taken. With min_after_dequeue: 1 the last
                  // element is presumably held back, so a 10th dequeue would wait.
                  for (int step = 0; step < 9; step++)
                  {
                      dequeued.Add((int)sess.run(x));
                  }
              }

              // Every dequeued value must come from the enqueued set, with no repeats.
              Assert.AreEqual(9, dequeued.Distinct().Count());
              Assert.IsTrue(dequeued.All(v => v >= 1 && v <= 10));
              // Output should be shuffled. Getting exactly 1..9 in order is
              // possible but vanishingly unlikely.
              CollectionAssert.AreNotEqual(Enumerable.Range(1, 9).ToList(), dequeued);
          }
      }
  }