diff --git a/src/TensorFlowNET.Models/ObjectDetection/AnchorGenerators/GridAnchorGenerator.cs b/src/TensorFlowNET.Models/ObjectDetection/AnchorGenerators/GridAnchorGenerator.cs
new file mode 100644
index 00000000..c5093dfa
--- /dev/null
+++ b/src/TensorFlowNET.Models/ObjectDetection/AnchorGenerators/GridAnchorGenerator.cs
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Models.ObjectDetection
+{
+ public class GridAnchorGenerator : Core.AnchorGenerator
+ {
+ public GridAnchorGenerator(float[] scales = null)
+ {
+ if (scales == null)
+ scales = new[] { 0.5f, 1.0f, 2.0f };
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/AnchorGeneratorBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/AnchorGeneratorBuilder.cs
new file mode 100644
index 00000000..f220bccd
--- /dev/null
+++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/AnchorGeneratorBuilder.cs
@@ -0,0 +1,27 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow.Models.ObjectDetection.Protos;
+using static Tensorflow.Models.ObjectDetection.Protos.AnchorGenerator;
+
+namespace Tensorflow.Models.ObjectDetection
+{
+ public class AnchorGeneratorBuilder
+ {
+ public AnchorGeneratorBuilder()
+ {
+
+ }
+
+ public GridAnchorGenerator build(AnchorGenerator anchor_generator_config)
+ {
+ if(anchor_generator_config.AnchorGeneratorOneofCase == AnchorGeneratorOneofOneofCase.GridAnchorGenerator)
+ {
+ var grid_anchor_generator_config = anchor_generator_config.GridAnchorGenerator;
+ return new GridAnchorGenerator(scales: grid_anchor_generator_config.Scales.Select(x => float.Parse(x.ToString())).ToArray());
+ }
+ throw new NotImplementedException("");
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs
index 1f678098..2f2d3d85 100644
--- a/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs
+++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs
@@ -10,6 +10,7 @@ namespace Tensorflow.Models.ObjectDetection
{
ImageResizerBuilder _image_resizer_builder;
FasterRCNNFeatureExtractor _feature_extractor;
+ AnchorGeneratorBuilder anchor_generator_builder;
public ModelBuilder()
{
@@ -46,8 +47,12 @@ namespace Tensorflow.Models.ObjectDetection
var num_classes = frcnn_config.NumClasses;
var image_resizer_fn = _image_resizer_builder.build(frcnn_config.ImageResizer);
- var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate;
+ var feature_extractor = _build_faster_rcnn_feature_extractor(frcnn_config.FeatureExtractor, is_training,
+ inplace_batchnorm_update: frcnn_config.InplaceBatchnormUpdate);
+
var number_of_stages = frcnn_config.NumberOfStages;
+ var first_stage_anchor_generator = anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator);
+ var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate;
return new FasterRCNNMetaArch(new FasterRCNNInitArgs
{
@@ -65,5 +70,19 @@ namespace Tensorflow.Models.ObjectDetection
{
throw new NotImplementedException("");
}
+
+ private FasterRCNNFeatureExtractor _build_faster_rcnn_feature_extractor(FasterRcnnFeatureExtractor feature_extractor_config,
+ bool is_training, bool reuse_weights = false, bool inplace_batchnorm_update = false)
+ {
+ if (inplace_batchnorm_update)
+ throw new ValueError("inplace batchnorm updates not supported.");
+ var feature_type = feature_extractor_config.Type;
+ var first_stage_features_stride = feature_extractor_config.FirstStageFeaturesStride;
+ var batch_norm_trainable = feature_extractor_config.BatchNormTrainable;
+
+ return new FasterRCNNResnet101FeatureExtractor(is_training, first_stage_features_stride,
+ batch_norm_trainable: batch_norm_trainable,
+ reuse_weights: reuse_weights);
+ }
}
}
diff --git a/src/TensorFlowNET.Models/ObjectDetection/Core/AnchorGenerator.cs b/src/TensorFlowNET.Models/ObjectDetection/Core/AnchorGenerator.cs
new file mode 100644
index 00000000..af44ee3f
--- /dev/null
+++ b/src/TensorFlowNET.Models/ObjectDetection/Core/AnchorGenerator.cs
@@ -0,0 +1,10 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Models.ObjectDetection.Core
+{
+ public class AnchorGenerator
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNFeatureExtractor.cs b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNFeatureExtractor.cs
index bc1f0d46..bdcfae76 100644
--- a/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNFeatureExtractor.cs
+++ b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNFeatureExtractor.cs
@@ -4,7 +4,28 @@ using System.Text;
namespace Tensorflow.Models.ObjectDetection
{
+ ///
+ /// Faster R-CNN Feature Extractor definition.
+ ///
public class FasterRCNNFeatureExtractor
{
+ bool _is_training;
+ int _first_stage_features_stride;
+ bool _reuse_weights = false;
+ float _weight_decay = 0.0f;
+ bool _train_batch_norm;
+
+ public FasterRCNNFeatureExtractor(bool is_training,
+ int first_stage_features_stride,
+ bool batch_norm_trainable = false,
+ bool reuse_weights = false,
+ float weight_decay = 0.0f)
+ {
+ _is_training = is_training;
+ _first_stage_features_stride = first_stage_features_stride;
+ _train_batch_norm = (batch_norm_trainable && is_training);
+ _reuse_weights = reuse_weights;
+ _weight_decay = weight_decay;
+ }
}
}
diff --git a/src/TensorFlowNET.Models/ObjectDetection/Models/FasterRCNNResnet101FeatureExtractor.cs b/src/TensorFlowNET.Models/ObjectDetection/Models/FasterRCNNResnet101FeatureExtractor.cs
new file mode 100644
index 00000000..75e21ade
--- /dev/null
+++ b/src/TensorFlowNET.Models/ObjectDetection/Models/FasterRCNNResnet101FeatureExtractor.cs
@@ -0,0 +1,32 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+using Tensorflow.Operations.Activation;
+using Tensorflow.Models.Slim.Nets;
+
+namespace Tensorflow.Models.ObjectDetection
+{
+ ///
+ /// Faster R-CNN Resnet 101 feature extractor implementation.
+ ///
+ public class FasterRCNNResnet101FeatureExtractor : FasterRCNNResnetV1FeatureExtractor
+ {
+ public FasterRCNNResnet101FeatureExtractor(bool is_training,
+ int first_stage_features_stride,
+ bool batch_norm_trainable = false,
+ bool reuse_weights = false,
+ float weight_decay = 0.0f,
+ IActivation activation_fn = null) : base("resnet_v1_101",
+ ResNetV1.resnet_v1_101,
+ is_training,
+ first_stage_features_stride,
+ batch_norm_trainable: batch_norm_trainable,
+ reuse_weights: reuse_weights,
+ weight_decay: weight_decay,
+ activation_fn: activation_fn)
+ {
+
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Models/ObjectDetection/Models/FasterRCNNResnetV1FeatureExtractor.cs b/src/TensorFlowNET.Models/ObjectDetection/Models/FasterRCNNResnetV1FeatureExtractor.cs
new file mode 100644
index 00000000..e4a8351b
--- /dev/null
+++ b/src/TensorFlowNET.Models/ObjectDetection/Models/FasterRCNNResnetV1FeatureExtractor.cs
@@ -0,0 +1,28 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+using Tensorflow.Operations.Activation;
+
+namespace Tensorflow.Models.ObjectDetection
+{
+ public class FasterRCNNResnetV1FeatureExtractor : FasterRCNNFeatureExtractor
+ {
+ public FasterRCNNResnetV1FeatureExtractor(string architecture,
+ Action resnet_model,
+ bool is_training,
+ int first_stage_features_stride,
+ bool batch_norm_trainable = false,
+ bool reuse_weights = false,
+ float weight_decay = 0.0f,
+ IActivation activation_fn = null) : base(is_training,
+ first_stage_features_stride,
+ batch_norm_trainable: batch_norm_trainable,
+ reuse_weights: reuse_weights,
+ weight_decay: weight_decay)
+ {
+ if (activation_fn == null)
+ activation_fn = tf.nn.relu();
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Models/Slim/Nets/ResNetV1.cs b/src/TensorFlowNET.Models/Slim/Nets/ResNetV1.cs
new file mode 100644
index 00000000..1d949434
--- /dev/null
+++ b/src/TensorFlowNET.Models/Slim/Nets/ResNetV1.cs
@@ -0,0 +1,14 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Models.Slim.Nets
+{
+ public class ResNetV1
+ {
+ public static void resnet_v1_101()
+ {
+ throw new NotImplementedException("");
+ }
+ }
+}
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs b/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs
index a31d21be..e70d5429 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs
@@ -73,7 +73,6 @@ namespace TensorFlowNET.Examples.ImageProcessing.ObjectDetection
public void PrepareData()
{
- throw new NotImplementedException();
}
public void Train(Session sess)