@@ -62,6 +62,33 @@ A FIFOQueue that supports batching variable-sized tensors by padding. A `Padding

#### PriorityQueue

A queue implementation that dequeues elements in prioritized order. A `PriorityQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `PriorityQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `types`, and whose shapes are optionally described by the `shapes` argument.

```csharp
[TestMethod]
public void PriorityQueue()
{
    var queue = tf.PriorityQueue(3, tf.@string);
    var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" });
    var x = queue.dequeue();

    using (var sess = tf.Session())
    {
        init.run();

        // output will be 2, 3, 4
        var result = sess.run(x);
        Assert.AreEqual(result[0].GetInt64(), 2L);

        result = sess.run(x);
        Assert.AreEqual(result[0].GetInt64(), 3L);

        result = sess.run(x);
        Assert.AreEqual(result[0].GetInt64(), 4L);
    }
}
```

#### RandomShuffleQueue

A queue implementation that dequeues elements in a random order. A `RandomShuffleQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `RandomShuffleQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument.
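Below is a minimal sketch of constructing the `RandomShuffleQueue` class introduced later in this change. It only exercises the constructor that the change adds; a `tf.RandomShuffleQueue` factory and enqueue/dequeue helpers are not shown in this change, so none are assumed here.

```csharp
[TestMethod]
public void RandomShuffleQueue()
{
    // Construct the queue directly: up to 10 string elements with unconstrained shapes.
    var queue = new RandomShuffleQueue(10,
        new[] { TF_DataType.TF_STRING },
        new[] { new TensorShape() });
}
```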
@@ -50,7 +50,7 @@ namespace Tensorflow
            string shared_name = null,
            string name = "padding_fifo_queue")
            => new PaddingFIFOQueue(capacity,
                new [] { dtype },
                new[] { dtype },
                new[] { shape },
                shared_name: shared_name,
                name: name);
@@ -86,7 +86,26 @@ namespace Tensorflow
            => new FIFOQueue(capacity,
                new[] { dtype },
                new[] { shape ?? new TensorShape() },
                new[] { name },
                shared_name: shared_name,
                name: name);
        /// <summary>
        /// Creates a queue that dequeues elements in a prioritized order.
        /// </summary>
        /// <param name="capacity">The upper bound on the number of elements that may be stored in the queue.</param>
        /// <param name="dtype">The data type of each element's value component.</param>
        /// <param name="shape">The shape of each element's value component.</param>
        /// <param name="shared_name">If non-empty, the queue is shared under this name across sessions.</param>
        /// <param name="name">A name for the queue operation.</param>
        /// <returns>A <see cref="PriorityQueue"/> instance.</returns>
        public PriorityQueue PriorityQueue(int capacity,
            TF_DataType dtype,
            TensorShape shape = null,
            string shared_name = null,
            string name = "priority_queue")
            => new PriorityQueue(capacity,
                new[] { dtype },
                new[] { shape ?? new TensorShape() },
                shared_name: shared_name,
                name: name);
    }
@@ -0,0 +1,66 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow.Queues
{
    public class PriorityQueue : QueueBase
    {
        public PriorityQueue(int capacity,
            TF_DataType[] dtypes,
            TensorShape[] shapes,
            string[] names = null,
            string shared_name = null,
            string name = "priority_queue")
            : base(dtypes: dtypes, shapes: shapes, names: names)
        {
            _queue_ref = gen_data_flow_ops.priority_queue_v2(
                component_types: dtypes,
                shapes: shapes,
                capacity: capacity,
                shared_name: shared_name,
                name: name);

            _name = _queue_ref.op.name.Split('/').Last();

            // The kernel prepends an int64 priority component to every element,
            // so mirror that here in the dtypes and shapes bookkeeping.
            var dtypes1 = dtypes.ToList();
            dtypes1.Insert(0, TF_DataType.TF_INT64);
            _dtypes = dtypes1.ToArray();

            var shapes1 = shapes.ToList();
            shapes1.Insert(0, new TensorShape());
            _shapes = shapes1.ToArray();
        }

        public Operation enqueue_many<T>(long[] indexes, T[] vals, string name = null)
        {
            return tf_with(ops.name_scope(name, $"{_name}_EnqueueMany", vals), scope =>
            {
                // Component 0 holds the priorities, component 1 the values.
                var vals_tensor1 = _check_enqueue_dtypes(indexes);
                var vals_tensor2 = _check_enqueue_dtypes(vals);

                var tensors = new List<Tensor>();
                tensors.AddRange(vals_tensor1);
                tensors.AddRange(vals_tensor2);

                return gen_data_flow_ops.queue_enqueue_many_v2(_queue_ref, tensors.ToArray(), name: scope);
            });
        }

        public Tensor[] dequeue(string name = null)
        {
            Tensor[] ret;

            if (name == null)
                name = $"{_name}_Dequeue";

            if (_queue_ref.dtype == TF_DataType.TF_RESOURCE)
                ret = gen_data_flow_ops.queue_dequeue_v2(_queue_ref, _dtypes, name: name);
            else
                ret = gen_data_flow_ops.queue_dequeue(_queue_ref, _dtypes, name: name);

            return ret;
        }
    }
}
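For orientation, the sketch below shows what `dequeue` hands back given the component layout set up in the constructor above. The `queue` and `sess` variables are assumed to come from a test like the one earlier in this change.

```csharp
// Assuming queue = tf.PriorityQueue(3, tf.@string) and a running Session sess,
// each dequeued element carries the implicit priority first, then the payload.
var outputs = sess.run(queue.dequeue());
var priority = outputs[0].GetInt64(); // component 0: the int64 priority key
var payload = outputs[1];             // component 1: the enqueued string value
```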
@@ -42,7 +42,7 @@ namespace Tensorflow.Queues
            });
        }

        private Tensor[] _check_enqueue_dtypes(object vals)
        protected Tensor[] _check_enqueue_dtypes(object vals)
        {
            var tensors = new List<Tensor>();

@@ -56,12 +56,10 @@ namespace Tensorflow.Queues
                    }
                    break;

                case int[] vals1:
                    tensors.Add(ops.convert_to_tensor(vals1, dtype: _dtypes[0], name: $"component_0"));
                    break;

                default:
                    throw new NotImplementedException("");
                    // PriorityQueue reserves component 0 for the int64 priority,
                    // so its payload values are converted with component 1's dtype.
                    var dtype1 = GetType().Name == "PriorityQueue" ? _dtypes[1] : _dtypes[0];
                    tensors.Add(ops.convert_to_tensor(vals, dtype: dtype1, name: $"component_0"));
                    break;
            }

            return tensors.ToArray();
@@ -0,0 +1,28 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow.Queues
{
    public class RandomShuffleQueue : QueueBase
    {
        public RandomShuffleQueue(int capacity,
            TF_DataType[] dtypes,
            TensorShape[] shapes,
            string[] names = null,
            string shared_name = null,
            string name = "randomshuffle_fifo_queue")
            : base(dtypes: dtypes, shapes: shapes, names: names)
        {
            // NOTE: currently backed by the PaddingFIFOQueue kernel, so the
            // dequeue order is not yet randomized.
            _queue_ref = gen_data_flow_ops.padding_fifo_queue_v2(
                component_types: dtypes,
                shapes: shapes,
                capacity: capacity,
                shared_name: shared_name,
                name: name);

            _name = _queue_ref.op.name.Split('/').Last();
        }
    }
}
@@ -77,6 +77,22 @@ namespace Tensorflow
            return _op.output;
        }

        public static Tensor priority_queue_v2(TF_DataType[] component_types, TensorShape[] shapes,
            int capacity = -1, string container = "", string shared_name = "",
            string name = null)
        {
            var _op = _op_def_lib._apply_op_helper("PriorityQueueV2", name, new
            {
                component_types,
                shapes,
                capacity,
                container,
                shared_name
            });

            return _op.output;
        }

        public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null)
        {
            var _op = _op_def_lib._apply_op_helper("QueueEnqueue", name, new
@@ -10,11 +10,12 @@ namespace Tensorflow.Models.ObjectDetection
    {
        ImageResizerBuilder _image_resizer_builder;
        FasterRCNNFeatureExtractor _feature_extractor;
        AnchorGeneratorBuilder anchor_generator_builder;
        AnchorGeneratorBuilder _anchor_generator_builder;

        public ModelBuilder()
        {
            _image_resizer_builder = new ImageResizerBuilder();
            _anchor_generator_builder = new AnchorGeneratorBuilder();
        }

        /// <summary>
@@ -51,7 +52,7 @@ namespace Tensorflow.Models.ObjectDetection
                inplace_batchnorm_update: frcnn_config.InplaceBatchnormUpdate);

            var number_of_stages = frcnn_config.NumberOfStages;
            var first_stage_anchor_generator = anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator);
            var first_stage_anchor_generator = _anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator);
            var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate;

            return new FasterRCNNMetaArch(new FasterRCNNInitArgs
@@ -1,143 +1,133 @@
# Faster R-CNN with Resnet-101 (v1), configured for Pascal VOC Dataset.
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.

model {
  faster_rcnn {
    num_classes: 20
    image_resizer {
      keep_aspect_ratio_resizer {
        min_dimension: 600
        max_dimension: 1024
      }
    }
    feature_extractor {
      type: 'faster_rcnn_resnet101'
      first_stage_features_stride: 16
    }
    first_stage_anchor_generator {
      grid_anchor_generator {
        scales: [0.25, 0.5, 1.0, 2.0]
        aspect_ratios: [0.5, 1.0, 2.0]
        height_stride: 16
        width_stride: 16
      }
    }
    first_stage_box_predictor_conv_hyperparams {
      op: CONV
      regularizer {
        l2_regularizer {
          weight: 0.0
        }
      }
      initializer {
        truncated_normal_initializer {
          stddev: 0.01
        }
      }
    }
    first_stage_nms_score_threshold: 0.0
    first_stage_nms_iou_threshold: 0.7
    first_stage_max_proposals: 300
    first_stage_localization_loss_weight: 2.0
    first_stage_objectness_loss_weight: 1.0
    initial_crop_size: 14
    maxpool_kernel_size: 2
    maxpool_stride: 2
    second_stage_box_predictor {
      mask_rcnn_box_predictor {
        use_dropout: false
        dropout_keep_probability: 1.0
        fc_hyperparams {
          op: FC
          regularizer {
            l2_regularizer {
              weight: 0.0
            }
          }
          initializer {
            variance_scaling_initializer {
              factor: 1.0
              uniform: true
              mode: FAN_AVG
            }
          }
        }
      }
    }
    second_stage_post_processing {
      batch_non_max_suppression {
        score_threshold: 0.0
        iou_threshold: 0.6
        max_detections_per_class: 100
        max_total_detections: 300
      }
      score_converter: SOFTMAX
    }
    second_stage_localization_loss_weight: 2.0
    second_stage_classification_loss_weight: 1.0
  }
}

train_config: {
  batch_size: 1
  optimizer {
    momentum_optimizer: {
      learning_rate: {
        manual_step_learning_rate {
          initial_learning_rate: 0.0001
          schedule {
            step: 500000
            learning_rate: .00001
          }
          schedule {
            step: 700000
            learning_rate: .000001
          }
        }
      }
      momentum_optimizer_value: 0.9
    }
    use_moving_average: false
  }
  gradient_clipping_by_norm: 10.0
  fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
  from_detection_checkpoint: true
  num_steps: 800000
  data_augmentation_options {
    random_horizontal_flip {
    }
  }
}

train_input_reader: {
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/pascal_train.record"
  }
  label_map_path: "PATH_TO_BE_CONFIGURED/pascal_label_map.pbtxt"
}

eval_config: {
  num_examples: 4952
}

eval_input_reader: {
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/pascal_val.record"
  }
  label_map_path: "PATH_TO_BE_CONFIGURED/pascal_label_map.pbtxt"
  shuffle: false
  num_readers: 1
}
@@ -1,4 +1,5 @@
using System;
using Protobuf.Text;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
@@ -10,8 +11,8 @@ namespace Tensorflow.Models.ObjectDetection.Utils
    {
        public static TrainEvalPipelineConfig get_configs_from_pipeline_file(string pipeline_config_path)
        {
            var json = File.ReadAllText(pipeline_config_path);
            var pipeline_config = TrainEvalPipelineConfig.Parser.ParseJson(json);
            var config = File.ReadAllText(pipeline_config_path);
            var pipeline_config = TrainEvalPipelineConfig.Parser.ParseText(config);
            return pipeline_config;
        }
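With the parser switched to protobuf text format, the pipeline file above can be loaded directly. A rough usage sketch follows; the `ConfigUtil` class name, the relative config path, and the generated property names are assumptions for illustration only.

```csharp
// Hypothetical call site; assumes the method lives on a ConfigUtil class
// and that the .config file sits next to the binary.
var configs = ConfigUtil.get_configs_from_pipeline_file(
    "faster_rcnn_resnet101_voc07.config");

// Generated protobuf accessors follow the proto field names.
Console.WriteLine(configs.Model.FasterRcnn.NumClasses); // 20 for Pascal VOC
Console.WriteLine(configs.TrainConfig.NumSteps);        // 800000
```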
@@ -4,6 +4,16 @@
    <TargetFramework>netcoreapp2.2</TargetFramework>
    <AssemblyName>TensorFlow.Models</AssemblyName>
    <RootNamespace>Tensorflow.Models</RootNamespace>
    <Version>0.0.1</Version>
    <Authors>Haiping Chen</Authors>
    <Company>SciSharp STACK</Company>
    <PackageProjectUrl>https://github.com/SciSharp/TensorFlow.NET</PackageProjectUrl>
    <RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl>
    <RepositoryType>git</RepositoryType>
    <PackageTags>TensorFlow</PackageTags>
    <Description>Models and examples built with TensorFlow.</Description>
    <GeneratePackageOnBuild>true</GeneratePackageOnBuild>
    <Copyright>Apache2</Copyright>
  </PropertyGroup>

  <ItemGroup>
@@ -16,6 +26,10 @@
    </Content>
  </ItemGroup>

  <ItemGroup>
    <PackageReference Include="Protobuf.Text" Version="0.3.1" />
  </ItemGroup>

  <ItemGroup>
    <ProjectReference Include="..\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
  </ItemGroup>
@@ -70,5 +70,27 @@ namespace TensorFlowNET.UnitTest
                // until the queue has more elements.
            }
        }

        [TestMethod]
        public void PriorityQueue()
        {
            var queue = tf.PriorityQueue(3, tf.@string);
            var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" });
            var x = queue.dequeue();

            using (var sess = tf.Session())
            {
                init.run();

                // dequeues in priority order: 2, 3, 4
                var result = sess.run(x);
                Assert.AreEqual(result[0].GetInt64(), 2L);

                result = sess.run(x);
                Assert.AreEqual(result[0].GetInt64(), 3L);

                result = sess.run(x);
                Assert.AreEqual(result[0].GetInt64(), 4L);
            }
        }
    }
}