From 8b0e5cfca2d38fcf71d3a9719a564af8e9276a9f Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Tue, 24 Sep 2019 04:23:53 -0500 Subject: [PATCH] added RandomShuffleQueue and docs updated. --- docs/source/Queue.md | 48 +++++++++++++++++++ src/TensorFlowNET.Core/APIs/tf.queue.cs | 15 ++++++ .../Operations/Queues/FIFOQueue.cs | 18 ++++++- .../Operations/Queues/PaddingFIFOQueue.cs | 18 ++++++- .../Operations/Queues/PriorityQueue.cs | 18 ++++++- .../Operations/Queues/QueueBase.cs | 18 ++++++- .../Operations/Queues/RandomShuffleQueue.cs | 35 ++++++++++++-- .../Operations/gen_data_flow_ops.cs | 19 ++++++++ test/TensorFlowNET.UnitTest/QueueTest.cs | 20 ++++++++ 9 files changed, 202 insertions(+), 7 deletions(-) diff --git a/docs/source/Queue.md b/docs/source/Queue.md index b846278b..7f137fb3 100644 --- a/docs/source/Queue.md +++ b/docs/source/Queue.md @@ -58,6 +58,32 @@ Creates a queue that dequeues elements in a first-in first-out order. A `FIFOQue A FIFOQueue that supports batching variable-sized tensors by padding. A `PaddingFIFOQueue` may contain components with dynamic shape, while also supporting `dequeue_many`. A `PaddingFIFOQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are described by the `shapes` argument. +```chsarp +[TestMethod] +public void PaddingFIFOQueue() +{ + var numbers = tf.placeholder(tf.int32); + var queue = tf.PaddingFIFOQueue(10, tf.int32, new TensorShape(-1)); + var enqueue = queue.enqueue(numbers); + var dequeue_many = queue.dequeue_many(n: 3); + + using(var sess = tf.Session()) + { + sess.run(enqueue, (numbers, new[] { 1 })); + sess.run(enqueue, (numbers, new[] { 2, 3 })); + sess.run(enqueue, (numbers, new[] { 3, 4, 5 })); + + var result = sess.run(dequeue_many[0]); + + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray())); + } +} +``` + + + #### PriorityQueue A queue implementation that dequeues elements in prioritized order. A `PriorityQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `PriorityQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `types`, and whose shapes are optionally described by the `shapes` argument. @@ -93,6 +119,28 @@ public void PriorityQueue() A queue implementation that dequeues elements in a random order. A `RandomShuffleQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `RandomShuffleQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument. +```csharp +[TestMethod] +public void RandomShuffleQueue() +{ + var queue = tf.RandomShuffleQueue(10, min_after_dequeue: 1, dtype: tf.int32); + var init = queue.enqueue_many(new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }); + var x = queue.dequeue(); + + string results = ""; + using (var sess = tf.Session()) + { + init.run(); + + foreach(var i in range(9)) + results += (int)sess.run(x) + "."; + + // output in random order + // 1.2.3.4.5.6.7.8.9. + } +} +``` + Queue methods must run on the same device as the queue. `FIFOQueue` and `RandomShuffleQueue` are important TensorFlow objects for computing tensor asynchronously in a graph. For example, a typical input architecture is to use a `RandomShuffleQueue` to prepare inputs for training a model: diff --git a/src/TensorFlowNET.Core/APIs/tf.queue.cs b/src/TensorFlowNET.Core/APIs/tf.queue.cs index 1a9641b4..91947e5b 100644 --- a/src/TensorFlowNET.Core/APIs/tf.queue.cs +++ b/src/TensorFlowNET.Core/APIs/tf.queue.cs @@ -108,5 +108,20 @@ namespace Tensorflow new[] { shape ?? new TensorShape() }, shared_name: shared_name, name: name); + + public RandomShuffleQueue RandomShuffleQueue(int capacity, + int min_after_dequeue, + TF_DataType dtype, + TensorShape shape = null, + int? seed = null, + string shared_name = null, + string name = "random_shuffle_queue") + => new RandomShuffleQueue(capacity, + min_after_dequeue: min_after_dequeue, + new[] { dtype }, + new[] { shape ?? new TensorShape() }, + seed: seed, + shared_name: shared_name, + name: name); } } diff --git a/src/TensorFlowNET.Core/Operations/Queues/FIFOQueue.cs b/src/TensorFlowNET.Core/Operations/Queues/FIFOQueue.cs index fd4aa13f..b4d2e638 100644 --- a/src/TensorFlowNET.Core/Operations/Queues/FIFOQueue.cs +++ b/src/TensorFlowNET.Core/Operations/Queues/FIFOQueue.cs @@ -1,4 +1,20 @@ -using System; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; using System.Collections.Generic; using System.Linq; using System.Text; diff --git a/src/TensorFlowNET.Core/Operations/Queues/PaddingFIFOQueue.cs b/src/TensorFlowNET.Core/Operations/Queues/PaddingFIFOQueue.cs index 3869f3b0..d8b93ff2 100644 --- a/src/TensorFlowNET.Core/Operations/Queues/PaddingFIFOQueue.cs +++ b/src/TensorFlowNET.Core/Operations/Queues/PaddingFIFOQueue.cs @@ -1,4 +1,20 @@ -using System; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; using System.Collections.Generic; using System.Linq; using System.Text; diff --git a/src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs b/src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs index b41e1a0c..7420c017 100644 --- a/src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs +++ b/src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs @@ -1,4 +1,20 @@ -using System; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; using System.Collections.Generic; using System.Linq; using System.Text; diff --git a/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs b/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs index 38821d9d..b420d2c9 100644 --- a/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs +++ b/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs @@ -1,4 +1,20 @@ -using System; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; using System.Collections.Generic; using System.Linq; using System.Text; diff --git a/src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs b/src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs index 5765f081..6846f478 100644 --- a/src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs +++ b/src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs @@ -1,24 +1,53 @@ -using System; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; using System.Collections.Generic; using System.Linq; using System.Text; namespace Tensorflow.Queues { + /// + /// Create a queue that dequeues elements in a random order. + /// public class RandomShuffleQueue : QueueBase { public RandomShuffleQueue(int capacity, + int min_after_dequeue, TF_DataType[] dtypes, TensorShape[] shapes, string[] names = null, + int? seed = null, string shared_name = null, - string name = "randomshuffle_fifo_queue") + string name = "random_shuffle_queue") : base(dtypes: dtypes, shapes: shapes, names: names) { - _queue_ref = gen_data_flow_ops.padding_fifo_queue_v2( + var(seed1, seed2) = random_seed.get_seed(seed); + if (!seed1.HasValue && !seed2.HasValue) + (seed1, seed2) = (0, 0); + + + _queue_ref = gen_data_flow_ops.random_shuffle_queue_v2( component_types: dtypes, shapes: shapes, capacity: capacity, + min_after_dequeue: min_after_dequeue, + seed: seed1.Value, + seed2: seed2.Value, shared_name: shared_name, name: name); diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs index b752268f..4e5bd1f6 100644 --- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs @@ -93,6 +93,25 @@ namespace Tensorflow return _op.output; } + public static Tensor random_shuffle_queue_v2(TF_DataType[] component_types, TensorShape[] shapes, + int capacity = -1, int min_after_dequeue = 0, int seed = 0, int seed2 = 0, + string container = "", string shared_name = "", string name = null) + { + var _op = _op_def_lib._apply_op_helper("RandomShuffleQueueV2", name, new + { + component_types, + shapes, + capacity, + min_after_dequeue, + seed, + seed2, + container, + shared_name + }); + + return _op.output; + } + public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) { var _op = _op_def_lib._apply_op_helper("QueueEnqueue", name, new diff --git a/test/TensorFlowNET.UnitTest/QueueTest.cs b/test/TensorFlowNET.UnitTest/QueueTest.cs index d546d961..731635b7 100644 --- a/test/TensorFlowNET.UnitTest/QueueTest.cs +++ b/test/TensorFlowNET.UnitTest/QueueTest.cs @@ -92,5 +92,25 @@ namespace TensorFlowNET.UnitTest Assert.AreEqual(result[0].GetInt64(), 4L); } } + + [TestMethod] + public void RandomShuffleQueue() + { + var queue = tf.RandomShuffleQueue(10, min_after_dequeue: 1, dtype: tf.int32); + var init = queue.enqueue_many(new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }); + var x = queue.dequeue(); + + string results = ""; + using (var sess = tf.Session()) + { + init.run(); + + foreach(var i in range(9)) + results += (int)sess.run(x) + "."; + + // output in random order + Assert.IsFalse(results == "1.2.3.4.5.6.7.8.9."); + } + } } }