From a596dbe990ebc7026df1e25a282bb52c1dad21fa Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 22 Sep 2019 22:32:31 -0500 Subject: [PATCH] tf.PriorityQueue, Protobuf.TextFormat --- docs/source/Queue.md | 27 ++ src/TensorFlowNET.Core/APIs/tf.queue.cs | 23 +- .../Operations/Queues/PriorityQueue.cs | 66 +++++ .../Operations/Queues/QueueBase.cs | 10 +- .../Operations/Queues/RandomShuffleQueue.cs | 28 ++ .../Operations/gen_data_flow_ops.cs | 16 ++ .../ObjectDetection/Builders/ModelBuilder.cs | 5 +- .../Models/faster_rcnn_resnet101_voc07.config | 256 +++++++++--------- .../ObjectDetection/Utils/ConfigUtil.cs | 7 +- .../TensorFlowNET.Models.csproj | 14 + test/TensorFlowNET.UnitTest/QueueTest.cs | 22 ++ 11 files changed, 328 insertions(+), 146 deletions(-) create mode 100644 src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs create mode 100644 src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs diff --git a/docs/source/Queue.md b/docs/source/Queue.md index bd73fd5a..b846278b 100644 --- a/docs/source/Queue.md +++ b/docs/source/Queue.md @@ -62,6 +62,33 @@ A FIFOQueue that supports batching variable-sized tensors by padding. A `Padding A queue implementation that dequeues elements in prioritized order. A `PriorityQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `PriorityQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `types`, and whose shapes are optionally described by the `shapes` argument. +```csharp +[TestMethod] +public void PriorityQueue() +{ + var queue = tf.PriorityQueue(3, tf.@string); + var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" }); + var x = queue.dequeue(); + + using (var sess = tf.Session()) + { + init.run(); + + // output will 2, 3, 4 + var result = sess.run(x); + Assert.AreEqual(result[0].GetInt64(), 2L); + + result = sess.run(x); + Assert.AreEqual(result[0].GetInt64(), 3L); + + result = sess.run(x); + Assert.AreEqual(result[0].GetInt64(), 4L); + } +} +``` + + + #### RandomShuffleQueue A queue implementation that dequeues elements in a random order. A `RandomShuffleQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `RandomShuffleQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument. diff --git a/src/TensorFlowNET.Core/APIs/tf.queue.cs b/src/TensorFlowNET.Core/APIs/tf.queue.cs index f81f5726..1a9641b4 100644 --- a/src/TensorFlowNET.Core/APIs/tf.queue.cs +++ b/src/TensorFlowNET.Core/APIs/tf.queue.cs @@ -50,7 +50,7 @@ namespace Tensorflow string shared_name = null, string name = "padding_fifo_queue") => new PaddingFIFOQueue(capacity, - new [] { dtype }, + new[] { dtype }, new[] { shape }, shared_name: shared_name, name: name); @@ -86,7 +86,26 @@ namespace Tensorflow => new FIFOQueue(capacity, new[] { dtype }, new[] { shape ?? new TensorShape() }, - new[] { name }, + shared_name: shared_name, + name: name); + + /// + /// Creates a queue that dequeues elements in a first-in first-out order. + /// + /// + /// + /// + /// + /// + /// + public PriorityQueue PriorityQueue(int capacity, + TF_DataType dtype, + TensorShape shape = null, + string shared_name = null, + string name = "priority_queue") + => new PriorityQueue(capacity, + new[] { dtype }, + new[] { shape ?? new TensorShape() }, shared_name: shared_name, name: name); } diff --git a/src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs b/src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs new file mode 100644 index 00000000..b41e1a0c --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs @@ -0,0 +1,66 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Queues +{ + public class PriorityQueue : QueueBase + { + public PriorityQueue(int capacity, + TF_DataType[] dtypes, + TensorShape[] shapes, + string[] names = null, + string shared_name = null, + string name = "priority_queue") + : base(dtypes: dtypes, shapes: shapes, names: names) + { + _queue_ref = gen_data_flow_ops.priority_queue_v2( + component_types: dtypes, + shapes: shapes, + capacity: capacity, + shared_name: shared_name, + name: name); + + _name = _queue_ref.op.name.Split('/').Last(); + + var dtypes1 = dtypes.ToList(); + dtypes1.Insert(0, TF_DataType.TF_INT64); + _dtypes = dtypes1.ToArray(); + + var shapes1 = shapes.ToList(); + shapes1.Insert(0, new TensorShape()); + _shapes = shapes1.ToArray(); + } + + public Operation enqueue_many(long[] indexes, T[] vals, string name = null) + { + return tf_with(ops.name_scope(name, $"{_name}_EnqueueMany", vals), scope => + { + var vals_tensor1 = _check_enqueue_dtypes(indexes); + var vals_tensor2 = _check_enqueue_dtypes(vals); + + var tensors = new List(); + tensors.AddRange(vals_tensor1); + tensors.AddRange(vals_tensor2); + + return gen_data_flow_ops.queue_enqueue_many_v2(_queue_ref, tensors.ToArray(), name: scope); + }); + } + + public Tensor[] dequeue(string name = null) + { + Tensor[] ret; + if (name == null) + name = $"{_name}_Dequeue"; + + if (_queue_ref.dtype == TF_DataType.TF_RESOURCE) + ret = gen_data_flow_ops.queue_dequeue_v2(_queue_ref, _dtypes, name: name); + else + ret = gen_data_flow_ops.queue_dequeue(_queue_ref, _dtypes, name: name); + + return ret; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs b/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs index 0eb5816d..38821d9d 100644 --- a/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs +++ b/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs @@ -42,7 +42,7 @@ namespace Tensorflow.Queues }); } - private Tensor[] _check_enqueue_dtypes(object vals) + protected Tensor[] _check_enqueue_dtypes(object vals) { var tensors = new List(); @@ -56,12 +56,10 @@ namespace Tensorflow.Queues } break; - case int[] vals1: - tensors.Add(ops.convert_to_tensor(vals1, dtype: _dtypes[0], name: $"component_0")); - break; - default: - throw new NotImplementedException(""); + var dtype1 = GetType().Name == "PriorityQueue" ? _dtypes[1] : _dtypes[0]; + tensors.Add(ops.convert_to_tensor(vals, dtype: dtype1, name: $"component_0")); + break; } return tensors.ToArray(); diff --git a/src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs b/src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs new file mode 100644 index 00000000..5765f081 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow.Queues +{ + public class RandomShuffleQueue : QueueBase + { + public RandomShuffleQueue(int capacity, + TF_DataType[] dtypes, + TensorShape[] shapes, + string[] names = null, + string shared_name = null, + string name = "randomshuffle_fifo_queue") + : base(dtypes: dtypes, shapes: shapes, names: names) + { + _queue_ref = gen_data_flow_ops.padding_fifo_queue_v2( + component_types: dtypes, + shapes: shapes, + capacity: capacity, + shared_name: shared_name, + name: name); + + _name = _queue_ref.op.name.Split('/').Last(); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs index 4fd394d2..b752268f 100644 --- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs @@ -77,6 +77,22 @@ namespace Tensorflow return _op.output; } + public static Tensor priority_queue_v2(TF_DataType[] component_types, TensorShape[] shapes, + int capacity = -1, string container = "", string shared_name = "", + string name = null) + { + var _op = _op_def_lib._apply_op_helper("PriorityQueueV2", name, new + { + component_types, + shapes, + capacity, + container, + shared_name + }); + + return _op.output; + } + public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) { var _op = _op_def_lib._apply_op_helper("QueueEnqueue", name, new diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs index 2f2d3d85..596a7532 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs @@ -10,11 +10,12 @@ namespace Tensorflow.Models.ObjectDetection { ImageResizerBuilder _image_resizer_builder; FasterRCNNFeatureExtractor _feature_extractor; - AnchorGeneratorBuilder anchor_generator_builder; + AnchorGeneratorBuilder _anchor_generator_builder; public ModelBuilder() { _image_resizer_builder = new ImageResizerBuilder(); + _anchor_generator_builder = new AnchorGeneratorBuilder(); } /// @@ -51,7 +52,7 @@ namespace Tensorflow.Models.ObjectDetection inplace_batchnorm_update: frcnn_config.InplaceBatchnormUpdate); var number_of_stages = frcnn_config.NumberOfStages; - var first_stage_anchor_generator = anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator); + var first_stage_anchor_generator = _anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator); var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate; return new FasterRCNNMetaArch(new FasterRCNNInitArgs diff --git a/src/TensorFlowNET.Models/ObjectDetection/Models/faster_rcnn_resnet101_voc07.config b/src/TensorFlowNET.Models/ObjectDetection/Models/faster_rcnn_resnet101_voc07.config index d5ec5f38..7458f4a5 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/Models/faster_rcnn_resnet101_voc07.config +++ b/src/TensorFlowNET.Models/ObjectDetection/Models/faster_rcnn_resnet101_voc07.config @@ -1,143 +1,133 @@ -{ - "model": { - "fasterRcnn": { - "numClasses": 20, - "imageResizer": { - "keepAspectRatioResizer": { - "minDimension": 600, - "maxDimension": 1024 +# Faster R-CNN with Resnet-101 (v1), configured for Pascal VOC Dataset. +# Users should configure the fine_tune_checkpoint field in the train config as +# well as the label_map_path and input_path fields in the train_input_reader and +# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that +# should be configured. + +model { + faster_rcnn { + num_classes: 20 + image_resizer { + keep_aspect_ratio_resizer { + min_dimension: 600 + max_dimension: 1024 + } + } + feature_extractor { + type: 'faster_rcnn_resnet101' + first_stage_features_stride: 16 + } + first_stage_anchor_generator { + grid_anchor_generator { + scales: [0.25, 0.5, 1.0, 2.0] + aspect_ratios: [0.5, 1.0, 2.0] + height_stride: 16 + width_stride: 16 + } + } + first_stage_box_predictor_conv_hyperparams { + op: CONV + regularizer { + l2_regularizer { + weight: 0.0 } - }, - "featureExtractor": { - "type": "faster_rcnn_resnet101", - "firstStageFeaturesStride": 16 - }, - "firstStageAnchorGenerator": { - "gridAnchorGenerator": { - "heightStride": 16, - "widthStride": 16, - "scales": [ - 0.25, - 0.5, - 1.0, - 2.0 - ], - "aspectRatios": [ - 0.5, - 1.0, - 2.0 - ] + } + initializer { + truncated_normal_initializer { + stddev: 0.01 } - }, - "firstStageBoxPredictorConvHyperparams": { - "op": "CONV", - "regularizer": { - "l2Regularizer": { - "weight": 0.0 - } - }, - "initializer": { - "truncatedNormalInitializer": { - "stddev": 0.009999999776482582 + } + } + first_stage_nms_score_threshold: 0.0 + first_stage_nms_iou_threshold: 0.7 + first_stage_max_proposals: 300 + first_stage_localization_loss_weight: 2.0 + first_stage_objectness_loss_weight: 1.0 + initial_crop_size: 14 + maxpool_kernel_size: 2 + maxpool_stride: 2 + second_stage_box_predictor { + mask_rcnn_box_predictor { + use_dropout: false + dropout_keep_probability: 1.0 + fc_hyperparams { + op: FC + regularizer { + l2_regularizer { + weight: 0.0 + } } - } - }, - "firstStageNmsScoreThreshold": 0.0, - "firstStageNmsIouThreshold": 0.699999988079071, - "firstStageMaxProposals": 300, - "firstStageLocalizationLossWeight": 2.0, - "firstStageObjectnessLossWeight": 1.0, - "initialCropSize": 14, - "maxpoolKernelSize": 2, - "maxpoolStride": 2, - "secondStageBoxPredictor": { - "maskRcnnBoxPredictor": { - "fcHyperparams": { - "op": "FC", - "regularizer": { - "l2Regularizer": { - "weight": 0.0 - } - }, - "initializer": { - "varianceScalingInitializer": { - "factor": 1.0, - "uniform": true, - "mode": "FAN_AVG" - } + initializer { + variance_scaling_initializer { + factor: 1.0 + uniform: true + mode: FAN_AVG } - }, - "useDropout": false, - "dropoutKeepProbability": 1.0 + } } - }, - "secondStagePostProcessing": { - "batchNonMaxSuppression": { - "scoreThreshold": 0.0, - "iouThreshold": 0.6000000238418579, - "maxDetectionsPerClass": 100, - "maxTotalDetections": 300 - }, - "scoreConverter": "SOFTMAX" - }, - "secondStageLocalizationLossWeight": 2.0, - "secondStageClassificationLossWeight": 1.0 + } } - }, - "trainConfig": { - "batchSize": 1, - "dataAugmentationOptions": [ - { - "randomHorizontalFlip": {} + second_stage_post_processing { + batch_non_max_suppression { + score_threshold: 0.0 + iou_threshold: 0.6 + max_detections_per_class: 100 + max_total_detections: 300 } - ], - "optimizer": { - "momentumOptimizer": { - "learningRate": { - "manualStepLearningRate": { - "initialLearningRate": 9.999999747378752e-05, - "schedule": [ - { - "step": 500000, - "learningRate": 9.999999747378752e-06 - }, - { - "step": 700000, - "learningRate": 9.999999974752427e-07 - } - ] - } - }, - "momentumOptimizerValue": 0.8999999761581421 - }, - "useMovingAverage": false - }, - "gradientClippingByNorm": 10.0, - "fineTuneCheckpoint": "D:/tmp/faster_rcnn_resnet101_coco/model.ckpt", - "fromDetectionCheckpoint": true, - "numSteps": 800000 - }, - "trainInputReader": { - "labelMapPath": "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_label_map.pbtxt", - "tfRecordInputReader": { - "inputPath": [ - "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_train.record" - ] + score_converter: SOFTMAX } - }, - "evalConfig": { - "numExamples": 4952 - }, - "evalInputReader": [ - { - "labelMapPath": "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_label_map.pbtxt", - "shuffle": false, - "numReaders": 1, - "tfRecordInputReader": { - "inputPath": [ - "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_val.record" - ] + second_stage_localization_loss_weight: 2.0 + second_stage_classification_loss_weight: 1.0 + } +} + +train_config: { + batch_size: 1 + optimizer { + momentum_optimizer: { + learning_rate: { + manual_step_learning_rate { + initial_learning_rate: 0.0001 + schedule { + step: 500000 + learning_rate: .00001 + } + schedule { + step: 700000 + learning_rate: .000001 + } + } } + momentum_optimizer_value: 0.9 + } + use_moving_average: false + } + gradient_clipping_by_norm: 10.0 + fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt" + from_detection_checkpoint: true + num_steps: 800000 + data_augmentation_options { + random_horizontal_flip { } - ] -} \ No newline at end of file + } +} + +train_input_reader: { + tf_record_input_reader { + input_path: "PATH_TO_BE_CONFIGURED/pascal_train.record" + } + label_map_path: "PATH_TO_BE_CONFIGURED/pascal_label_map.pbtxt" +} + +eval_config: { + num_examples: 4952 +} + +eval_input_reader: { + tf_record_input_reader { + input_path: "PATH_TO_BE_CONFIGURED/pascal_val.record" + } + label_map_path: "PATH_TO_BE_CONFIGURED/pascal_label_map.pbtxt" + shuffle: false + num_readers: 1 +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Utils/ConfigUtil.cs b/src/TensorFlowNET.Models/ObjectDetection/Utils/ConfigUtil.cs index a8b3876e..2a6a672e 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/Utils/ConfigUtil.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/Utils/ConfigUtil.cs @@ -1,4 +1,5 @@ -using System; +using Protobuf.Text; +using System; using System.Collections.Generic; using System.IO; using System.Text; @@ -10,8 +11,8 @@ namespace Tensorflow.Models.ObjectDetection.Utils { public static TrainEvalPipelineConfig get_configs_from_pipeline_file(string pipeline_config_path) { - var json = File.ReadAllText(pipeline_config_path); - var pipeline_config = TrainEvalPipelineConfig.Parser.ParseJson(json); + var config = File.ReadAllText(pipeline_config_path); + var pipeline_config = TrainEvalPipelineConfig.Parser.ParseText(config); return pipeline_config; } diff --git a/src/TensorFlowNET.Models/TensorFlowNET.Models.csproj b/src/TensorFlowNET.Models/TensorFlowNET.Models.csproj index 291c7c03..aae55be9 100644 --- a/src/TensorFlowNET.Models/TensorFlowNET.Models.csproj +++ b/src/TensorFlowNET.Models/TensorFlowNET.Models.csproj @@ -4,6 +4,16 @@ netcoreapp2.2 TensorFlow.Models Tensorflow.Models + 0.0.1 + Haiping Chen + SciSharp STACK + https://github.com/SciSharp/TensorFlow.NET + https://github.com/SciSharp/TensorFlow.NET + git + TensorFlow + Models and examples built with TensorFlow. + true + Apache2 @@ -16,6 +26,10 @@ + + + + diff --git a/test/TensorFlowNET.UnitTest/QueueTest.cs b/test/TensorFlowNET.UnitTest/QueueTest.cs index 14afbae5..d546d961 100644 --- a/test/TensorFlowNET.UnitTest/QueueTest.cs +++ b/test/TensorFlowNET.UnitTest/QueueTest.cs @@ -70,5 +70,27 @@ namespace TensorFlowNET.UnitTest // until queue has more element. } } + + [TestMethod] + public void PriorityQueue() + { + var queue = tf.PriorityQueue(3, tf.@string); + var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" }); + var x = queue.dequeue(); + + using (var sess = tf.Session()) + { + init.run(); + + var result = sess.run(x); + Assert.AreEqual(result[0].GetInt64(), 2L); + + result = sess.run(x); + Assert.AreEqual(result[0].GetInt64(), 3L); + + result = sess.run(x); + Assert.AreEqual(result[0].GetInt64(), 4L); + } + } } }