Browse Source

tf.PriorityQueue, Protobuf.TextFormat

tags/v0.12
Oceania2018 6 years ago
parent
commit
a596dbe990
11 changed files with 328 additions and 146 deletions
  1. +27
    -0
      docs/source/Queue.md
  2. +21
    -2
      src/TensorFlowNET.Core/APIs/tf.queue.cs
  3. +66
    -0
      src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs
  4. +4
    -6
      src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs
  5. +28
    -0
      src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs
  6. +16
    -0
      src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
  7. +3
    -2
      src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs
  8. +123
    -133
      src/TensorFlowNET.Models/ObjectDetection/Models/faster_rcnn_resnet101_voc07.config
  9. +4
    -3
      src/TensorFlowNET.Models/ObjectDetection/Utils/ConfigUtil.cs
  10. +14
    -0
      src/TensorFlowNET.Models/TensorFlowNET.Models.csproj
  11. +22
    -0
      test/TensorFlowNET.UnitTest/QueueTest.cs

+ 27
- 0
docs/source/Queue.md View File

@@ -62,6 +62,33 @@ A FIFOQueue that supports batching variable-sized tensors by padding. A `Padding

A queue implementation that dequeues elements in prioritized order. A `PriorityQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `PriorityQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `types`, and whose shapes are optionally described by the `shapes` argument.

```csharp
[TestMethod]
public void PriorityQueue()
{
var queue = tf.PriorityQueue(3, tf.@string);
var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" });
var x = queue.dequeue();

using (var sess = tf.Session())
{
init.run();

// output will 2, 3, 4
var result = sess.run(x);
Assert.AreEqual(result[0].GetInt64(), 2L);

result = sess.run(x);
Assert.AreEqual(result[0].GetInt64(), 3L);

result = sess.run(x);
Assert.AreEqual(result[0].GetInt64(), 4L);
}
}
```



#### RandomShuffleQueue

A queue implementation that dequeues elements in a random order. A `RandomShuffleQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `RandomShuffleQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument.


+ 21
- 2
src/TensorFlowNET.Core/APIs/tf.queue.cs View File

@@ -50,7 +50,7 @@ namespace Tensorflow
string shared_name = null,
string name = "padding_fifo_queue")
=> new PaddingFIFOQueue(capacity,
new [] { dtype },
new[] { dtype },
new[] { shape },
shared_name: shared_name,
name: name);
@@ -86,7 +86,26 @@ namespace Tensorflow
=> new FIFOQueue(capacity,
new[] { dtype },
new[] { shape ?? new TensorShape() },
new[] { name },
shared_name: shared_name,
name: name);

/// <summary>
/// Creates a queue that dequeues elements in a first-in first-out order.
/// </summary>
/// <param name="capacity"></param>
/// <param name="dtype"></param>
/// <param name="shape"></param>
/// <param name="shared_name"></param>
/// <param name="name"></param>
/// <returns></returns>
public PriorityQueue PriorityQueue(int capacity,
TF_DataType dtype,
TensorShape shape = null,
string shared_name = null,
string name = "priority_queue")
=> new PriorityQueue(capacity,
new[] { dtype },
new[] { shape ?? new TensorShape() },
shared_name: shared_name,
name: name);
}


+ 66
- 0
src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs View File

@@ -0,0 +1,66 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow.Queues
{
public class PriorityQueue : QueueBase
{
public PriorityQueue(int capacity,
TF_DataType[] dtypes,
TensorShape[] shapes,
string[] names = null,
string shared_name = null,
string name = "priority_queue")
: base(dtypes: dtypes, shapes: shapes, names: names)
{
_queue_ref = gen_data_flow_ops.priority_queue_v2(
component_types: dtypes,
shapes: shapes,
capacity: capacity,
shared_name: shared_name,
name: name);

_name = _queue_ref.op.name.Split('/').Last();

var dtypes1 = dtypes.ToList();
dtypes1.Insert(0, TF_DataType.TF_INT64);
_dtypes = dtypes1.ToArray();

var shapes1 = shapes.ToList();
shapes1.Insert(0, new TensorShape());
_shapes = shapes1.ToArray();
}

public Operation enqueue_many<T>(long[] indexes, T[] vals, string name = null)
{
return tf_with(ops.name_scope(name, $"{_name}_EnqueueMany", vals), scope =>
{
var vals_tensor1 = _check_enqueue_dtypes(indexes);
var vals_tensor2 = _check_enqueue_dtypes(vals);

var tensors = new List<Tensor>();
tensors.AddRange(vals_tensor1);
tensors.AddRange(vals_tensor2);

return gen_data_flow_ops.queue_enqueue_many_v2(_queue_ref, tensors.ToArray(), name: scope);
});
}

public Tensor[] dequeue(string name = null)
{
Tensor[] ret;
if (name == null)
name = $"{_name}_Dequeue";

if (_queue_ref.dtype == TF_DataType.TF_RESOURCE)
ret = gen_data_flow_ops.queue_dequeue_v2(_queue_ref, _dtypes, name: name);
else
ret = gen_data_flow_ops.queue_dequeue(_queue_ref, _dtypes, name: name);

return ret;
}
}
}

+ 4
- 6
src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow.Queues
});
}

private Tensor[] _check_enqueue_dtypes(object vals)
protected Tensor[] _check_enqueue_dtypes(object vals)
{
var tensors = new List<Tensor>();

@@ -56,12 +56,10 @@ namespace Tensorflow.Queues
}
break;

case int[] vals1:
tensors.Add(ops.convert_to_tensor(vals1, dtype: _dtypes[0], name: $"component_0"));
break;

default:
throw new NotImplementedException("");
var dtype1 = GetType().Name == "PriorityQueue" ? _dtypes[1] : _dtypes[0];
tensors.Add(ops.convert_to_tensor(vals, dtype: dtype1, name: $"component_0"));
break;
}

return tensors.ToArray();


+ 28
- 0
src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs View File

@@ -0,0 +1,28 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow.Queues
{
public class RandomShuffleQueue : QueueBase
{
public RandomShuffleQueue(int capacity,
TF_DataType[] dtypes,
TensorShape[] shapes,
string[] names = null,
string shared_name = null,
string name = "randomshuffle_fifo_queue")
: base(dtypes: dtypes, shapes: shapes, names: names)
{
_queue_ref = gen_data_flow_ops.padding_fifo_queue_v2(
component_types: dtypes,
shapes: shapes,
capacity: capacity,
shared_name: shared_name,
name: name);

_name = _queue_ref.op.name.Split('/').Last();
}
}
}

+ 16
- 0
src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs View File

@@ -77,6 +77,22 @@ namespace Tensorflow
return _op.output;
}

public static Tensor priority_queue_v2(TF_DataType[] component_types, TensorShape[] shapes,
int capacity = -1, string container = "", string shared_name = "",
string name = null)
{
var _op = _op_def_lib._apply_op_helper("PriorityQueueV2", name, new
{
component_types,
shapes,
capacity,
container,
shared_name
});

return _op.output;
}

public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null)
{
var _op = _op_def_lib._apply_op_helper("QueueEnqueue", name, new


+ 3
- 2
src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs View File

@@ -10,11 +10,12 @@ namespace Tensorflow.Models.ObjectDetection
{
ImageResizerBuilder _image_resizer_builder;
FasterRCNNFeatureExtractor _feature_extractor;
AnchorGeneratorBuilder anchor_generator_builder;
AnchorGeneratorBuilder _anchor_generator_builder;

public ModelBuilder()
{
_image_resizer_builder = new ImageResizerBuilder();
_anchor_generator_builder = new AnchorGeneratorBuilder();
}

/// <summary>
@@ -51,7 +52,7 @@ namespace Tensorflow.Models.ObjectDetection
inplace_batchnorm_update: frcnn_config.InplaceBatchnormUpdate);

var number_of_stages = frcnn_config.NumberOfStages;
var first_stage_anchor_generator = anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator);
var first_stage_anchor_generator = _anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator);
var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate;

return new FasterRCNNMetaArch(new FasterRCNNInitArgs


+ 123
- 133
src/TensorFlowNET.Models/ObjectDetection/Models/faster_rcnn_resnet101_voc07.config View File

@@ -1,143 +1,133 @@
{
"model": {
"fasterRcnn": {
"numClasses": 20,
"imageResizer": {
"keepAspectRatioResizer": {
"minDimension": 600,
"maxDimension": 1024
# Faster R-CNN with Resnet-101 (v1), configured for Pascal VOC Dataset.
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.

model {
faster_rcnn {
num_classes: 20
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 600
max_dimension: 1024
}
}
feature_extractor {
type: 'faster_rcnn_resnet101'
first_stage_features_stride: 16
}
first_stage_anchor_generator {
grid_anchor_generator {
scales: [0.25, 0.5, 1.0, 2.0]
aspect_ratios: [0.5, 1.0, 2.0]
height_stride: 16
width_stride: 16
}
}
first_stage_box_predictor_conv_hyperparams {
op: CONV
regularizer {
l2_regularizer {
weight: 0.0
}
},
"featureExtractor": {
"type": "faster_rcnn_resnet101",
"firstStageFeaturesStride": 16
},
"firstStageAnchorGenerator": {
"gridAnchorGenerator": {
"heightStride": 16,
"widthStride": 16,
"scales": [
0.25,
0.5,
1.0,
2.0
],
"aspectRatios": [
0.5,
1.0,
2.0
]
}
initializer {
truncated_normal_initializer {
stddev: 0.01
}
},
"firstStageBoxPredictorConvHyperparams": {
"op": "CONV",
"regularizer": {
"l2Regularizer": {
"weight": 0.0
}
},
"initializer": {
"truncatedNormalInitializer": {
"stddev": 0.009999999776482582
}
}
first_stage_nms_score_threshold: 0.0
first_stage_nms_iou_threshold: 0.7
first_stage_max_proposals: 300
first_stage_localization_loss_weight: 2.0
first_stage_objectness_loss_weight: 1.0
initial_crop_size: 14
maxpool_kernel_size: 2
maxpool_stride: 2
second_stage_box_predictor {
mask_rcnn_box_predictor {
use_dropout: false
dropout_keep_probability: 1.0
fc_hyperparams {
op: FC
regularizer {
l2_regularizer {
weight: 0.0
}
}
}
},
"firstStageNmsScoreThreshold": 0.0,
"firstStageNmsIouThreshold": 0.699999988079071,
"firstStageMaxProposals": 300,
"firstStageLocalizationLossWeight": 2.0,
"firstStageObjectnessLossWeight": 1.0,
"initialCropSize": 14,
"maxpoolKernelSize": 2,
"maxpoolStride": 2,
"secondStageBoxPredictor": {
"maskRcnnBoxPredictor": {
"fcHyperparams": {
"op": "FC",
"regularizer": {
"l2Regularizer": {
"weight": 0.0
}
},
"initializer": {
"varianceScalingInitializer": {
"factor": 1.0,
"uniform": true,
"mode": "FAN_AVG"
}
initializer {
variance_scaling_initializer {
factor: 1.0
uniform: true
mode: FAN_AVG
}
},
"useDropout": false,
"dropoutKeepProbability": 1.0
}
}
},
"secondStagePostProcessing": {
"batchNonMaxSuppression": {
"scoreThreshold": 0.0,
"iouThreshold": 0.6000000238418579,
"maxDetectionsPerClass": 100,
"maxTotalDetections": 300
},
"scoreConverter": "SOFTMAX"
},
"secondStageLocalizationLossWeight": 2.0,
"secondStageClassificationLossWeight": 1.0
}
}
},
"trainConfig": {
"batchSize": 1,
"dataAugmentationOptions": [
{
"randomHorizontalFlip": {}
second_stage_post_processing {
batch_non_max_suppression {
score_threshold: 0.0
iou_threshold: 0.6
max_detections_per_class: 100
max_total_detections: 300
}
],
"optimizer": {
"momentumOptimizer": {
"learningRate": {
"manualStepLearningRate": {
"initialLearningRate": 9.999999747378752e-05,
"schedule": [
{
"step": 500000,
"learningRate": 9.999999747378752e-06
},
{
"step": 700000,
"learningRate": 9.999999974752427e-07
}
]
}
},
"momentumOptimizerValue": 0.8999999761581421
},
"useMovingAverage": false
},
"gradientClippingByNorm": 10.0,
"fineTuneCheckpoint": "D:/tmp/faster_rcnn_resnet101_coco/model.ckpt",
"fromDetectionCheckpoint": true,
"numSteps": 800000
},
"trainInputReader": {
"labelMapPath": "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_label_map.pbtxt",
"tfRecordInputReader": {
"inputPath": [
"D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_train.record"
]
score_converter: SOFTMAX
}
},
"evalConfig": {
"numExamples": 4952
},
"evalInputReader": [
{
"labelMapPath": "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_label_map.pbtxt",
"shuffle": false,
"numReaders": 1,
"tfRecordInputReader": {
"inputPath": [
"D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_val.record"
]
second_stage_localization_loss_weight: 2.0
second_stage_classification_loss_weight: 1.0
}
}

train_config: {
batch_size: 1
optimizer {
momentum_optimizer: {
learning_rate: {
manual_step_learning_rate {
initial_learning_rate: 0.0001
schedule {
step: 500000
learning_rate: .00001
}
schedule {
step: 700000
learning_rate: .000001
}
}
}
momentum_optimizer_value: 0.9
}
use_moving_average: false
}
gradient_clipping_by_norm: 10.0
fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
from_detection_checkpoint: true
num_steps: 800000
data_augmentation_options {
random_horizontal_flip {
}
]
}
}
}

train_input_reader: {
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/pascal_train.record"
}
label_map_path: "PATH_TO_BE_CONFIGURED/pascal_label_map.pbtxt"
}

eval_config: {
num_examples: 4952
}

eval_input_reader: {
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/pascal_val.record"
}
label_map_path: "PATH_TO_BE_CONFIGURED/pascal_label_map.pbtxt"
shuffle: false
num_readers: 1
}

+ 4
- 3
src/TensorFlowNET.Models/ObjectDetection/Utils/ConfigUtil.cs View File

@@ -1,4 +1,5 @@
using System;
using Protobuf.Text;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
@@ -10,8 +11,8 @@ namespace Tensorflow.Models.ObjectDetection.Utils
{
public static TrainEvalPipelineConfig get_configs_from_pipeline_file(string pipeline_config_path)
{
var json = File.ReadAllText(pipeline_config_path);
var pipeline_config = TrainEvalPipelineConfig.Parser.ParseJson(json);
var config = File.ReadAllText(pipeline_config_path);
var pipeline_config = TrainEvalPipelineConfig.Parser.ParseText(config);

return pipeline_config;
}


+ 14
- 0
src/TensorFlowNET.Models/TensorFlowNET.Models.csproj View File

@@ -4,6 +4,16 @@
<TargetFramework>netcoreapp2.2</TargetFramework>
<AssemblyName>TensorFlow.Models</AssemblyName>
<RootNamespace>Tensorflow.Models</RootNamespace>
<Version>0.0.1</Version>
<Authors>Haiping Chen</Authors>
<Company>SciSharp STACK</Company>
<PackageProjectUrl>https://github.com/SciSharp/TensorFlow.NET</PackageProjectUrl>
<RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl>
<RepositoryType>git</RepositoryType>
<PackageTags>TensorFlow</PackageTags>
<Description>Models and examples built with TensorFlow.</Description>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
<Copyright>Apache2</Copyright>
</PropertyGroup>

<ItemGroup>
@@ -16,6 +26,10 @@
</Content>
</ItemGroup>

<ItemGroup>
<PackageReference Include="Protobuf.Text" Version="0.3.1" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
</ItemGroup>


+ 22
- 0
test/TensorFlowNET.UnitTest/QueueTest.cs View File

@@ -70,5 +70,27 @@ namespace TensorFlowNET.UnitTest
// until queue has more element.
}
}

[TestMethod]
public void PriorityQueue()
{
var queue = tf.PriorityQueue(3, tf.@string);
var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" });
var x = queue.dequeue();

using (var sess = tf.Session())
{
init.run();

var result = sess.run(x);
Assert.AreEqual(result[0].GetInt64(), 2L);

result = sess.run(x);
Assert.AreEqual(result[0].GetInt64(), 3L);

result = sess.run(x);
Assert.AreEqual(result[0].GetInt64(), 4L);
}
}
}
}

Loading…
Cancel
Save