You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

QueueTest.cs 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using Tensorflow;
  7. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Tests for the queue ops exposed through <c>tf</c>:
      /// PaddingFIFOQueue, FIFOQueue, PriorityQueue and RandomShuffleQueue.
      /// </summary>
      [TestClass]
      public class QueueTest
      {
          /// <summary>
          /// Enqueues three int32 vectors of different lengths and dequeues them as one
          /// batch of 3; shorter rows are expected to be zero-padded to the longest row.
          /// </summary>
          [TestMethod]
          public void PaddingFIFOQueue()
          {
              var numbers = tf.placeholder(tf.int32);
              // Capacity 10, int32 elements, element shape (-1): the enqueued arrays below
              // are 1-D vectors whose length varies from element to element.
              var queue = tf.PaddingFIFOQueue(10, tf.int32, new TensorShape(-1));
              var enqueue = queue.enqueue(numbers);
              var dequeueMany = queue.dequeue_many(n: 3);

              using (var sess = tf.Session())
              {
                  sess.run(enqueue, (numbers, new[] { 1 }));
                  sess.run(enqueue, (numbers, new[] { 2, 3 }));
                  sess.run(enqueue, (numbers, new[] { 3, 4, 5 }));

                  var result = sess.run(dequeueMany[0]);

                  // CollectionAssert reports the first differing index on failure,
                  // unlike Assert.IsTrue(SequenceEqual(...)) which only reports "false".
                  CollectionAssert.AreEqual(new[] { 1, 0, 0 }, result[0].ToArray<int>());
                  CollectionAssert.AreEqual(new[] { 2, 3, 0 }, result[1].ToArray<int>());
                  CollectionAssert.AreEqual(new[] { 3, 4, 5 }, result[2].ToArray<int>());
              }
          }

          /// <summary>
          /// Cycles elements through a FIFO queue: each step dequeues the head element,
          /// adds 1 to it and enqueues the result at the tail.
          /// </summary>
          [TestMethod]
          public void FIFOQueue()
          {
              // First-in-first-out queue with capacity 2 holding int32 elements.
              var queue = tf.FIFOQueue(2, tf.int32);
              // Initializer op: pushes the 2 elements 10 and 20.
              var init = queue.enqueue_many(new[] { 10, 20 });
              // Pops the head element.
              var x = queue.dequeue();
              // Adds 1 to the popped element.
              var y = x + 1;
              // Pushes the incremented value back onto the tail.
              var inc = queue.enqueue(y);

              using (var sess = tf.Session())
              {
                  init.run();

                  // Fetching x and inc together runs the dequeue once per step, so each
                  // step pops one element and pushes (element + 1); the queue stays full.
                  (int dequeued, _) = sess.run((x, inc));
                  Assert.AreEqual(10, dequeued);
                  (dequeued, _) = sess.run((x, inc));
                  Assert.AreEqual(20, dequeued);
                  (dequeued, _) = sess.run((x, inc));
                  Assert.AreEqual(11, dequeued);
                  (dequeued, _) = sess.run((x, inc));
                  Assert.AreEqual(21, dequeued);

                  // The queue now still holds 12 and 22. Dequeue blocks on an empty queue,
                  // so only a third sess.run(x) without any enqueue in between would hang.
              }
          }

          /// <summary>
          /// Enqueues three string values with int64 priorities and checks they are
          /// dequeued in ascending priority order.
          /// </summary>
          [TestMethod]
          public void PriorityQueue()
          {
              var queue = tf.PriorityQueue(3, tf.@string);
              // Priorities 2, 4, 3 paired with the values "p1", "p2", "p3".
              var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" });
              var x = queue.dequeue();

              using (var sess = tf.Session())
              {
                  init.run();

                  // result[0] is the priority component of the dequeued element.
                  // MSTest's AreEqual takes (expected, actual) in that order.
                  var result = sess.run(x);
                  Assert.AreEqual(2L, result[0].GetInt64());
                  result = sess.run(x);
                  Assert.AreEqual(3L, result[0].GetInt64());
                  result = sess.run(x);
                  Assert.AreEqual(4L, result[0].GetInt64());
              }
          }

          /// <summary>
          /// Dequeues 9 of the 10 enqueued elements from a RandomShuffleQueue and checks
          /// that they are distinct members of the enqueued set, not in insertion order.
          /// </summary>
          [TestMethod]
          public void RandomShuffleQueue()
          {
              var queue = tf.RandomShuffleQueue(10, min_after_dequeue: 1, dtype: tf.int32);
              var init = queue.enqueue_many(new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
              var x = queue.dequeue();

              // With min_after_dequeue: 1 a dequeue blocks unless one element would remain
              // afterwards (the queue is never closed here), so only 9 of 10 can be taken.
              const int dequeueCount = 9;
              var dequeued = new int[dequeueCount];

              using (var sess = tf.Session())
              {
                  init.run();
                  for (var k = 0; k < dequeueCount; k++)
                  {
                      dequeued[k] = (int)sess.run(x);
                  }
              }

              var enqueued = Enumerable.Range(1, 10).ToArray();
              CollectionAssert.AllItemsAreUnique(dequeued);
              CollectionAssert.IsSubsetOf(dequeued, enqueued);

              // Output should be in random order. The queue is unseeded, so this check can
              // in principle fail spuriously, but only with a vanishingly small probability
              // (on the order of 1 in 10!, assuming uniform selection).
              CollectionAssert.AreNotEqual(Enumerable.Range(1, dequeueCount).ToArray(), dequeued);
          }
      }
  }