@@ -0,0 +1,15 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Models.ObjectDetection | |||||
{ | |||||
public class GridAnchorGenerator : Core.AnchorGenerator | |||||
{ | |||||
public GridAnchorGenerator(float[] scales = null) | |||||
{ | |||||
if (scales == null) | |||||
scales = new[] { 0.5f, 1.0f, 2.0f }; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,27 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Tensorflow.Models.ObjectDetection.Protos; | |||||
using static Tensorflow.Models.ObjectDetection.Protos.AnchorGenerator; | |||||
namespace Tensorflow.Models.ObjectDetection | |||||
{ | |||||
public class AnchorGeneratorBuilder | |||||
{ | |||||
public AnchorGeneratorBuilder() | |||||
{ | |||||
} | |||||
public GridAnchorGenerator build(AnchorGenerator anchor_generator_config) | |||||
{ | |||||
if(anchor_generator_config.AnchorGeneratorOneofCase == AnchorGeneratorOneofOneofCase.GridAnchorGenerator) | |||||
{ | |||||
var grid_anchor_generator_config = anchor_generator_config.GridAnchorGenerator; | |||||
return new GridAnchorGenerator(scales: grid_anchor_generator_config.Scales.Select(x => float.Parse(x.ToString())).ToArray()); | |||||
} | |||||
throw new NotImplementedException(""); | |||||
} | |||||
} | |||||
} |
@@ -10,6 +10,7 @@ namespace Tensorflow.Models.ObjectDetection | |||||
{ | { | ||||
ImageResizerBuilder _image_resizer_builder; | ImageResizerBuilder _image_resizer_builder; | ||||
FasterRCNNFeatureExtractor _feature_extractor; | FasterRCNNFeatureExtractor _feature_extractor; | ||||
AnchorGeneratorBuilder anchor_generator_builder; | |||||
public ModelBuilder() | public ModelBuilder() | ||||
{ | { | ||||
@@ -46,8 +47,12 @@ namespace Tensorflow.Models.ObjectDetection | |||||
var num_classes = frcnn_config.NumClasses; | var num_classes = frcnn_config.NumClasses; | ||||
var image_resizer_fn = _image_resizer_builder.build(frcnn_config.ImageResizer); | var image_resizer_fn = _image_resizer_builder.build(frcnn_config.ImageResizer); | ||||
var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate; | |||||
var feature_extractor = _build_faster_rcnn_feature_extractor(frcnn_config.FeatureExtractor, is_training, | |||||
inplace_batchnorm_update: frcnn_config.InplaceBatchnormUpdate); | |||||
var number_of_stages = frcnn_config.NumberOfStages; | var number_of_stages = frcnn_config.NumberOfStages; | ||||
var first_stage_anchor_generator = anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator); | |||||
var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate; | |||||
return new FasterRCNNMetaArch(new FasterRCNNInitArgs | return new FasterRCNNMetaArch(new FasterRCNNInitArgs | ||||
{ | { | ||||
@@ -65,5 +70,19 @@ namespace Tensorflow.Models.ObjectDetection | |||||
{ | { | ||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
private FasterRCNNFeatureExtractor _build_faster_rcnn_feature_extractor(FasterRcnnFeatureExtractor feature_extractor_config, | |||||
bool is_training, bool reuse_weights = false, bool inplace_batchnorm_update = false) | |||||
{ | |||||
if (inplace_batchnorm_update) | |||||
throw new ValueError("inplace batchnorm updates not supported."); | |||||
var feature_type = feature_extractor_config.Type; | |||||
var first_stage_features_stride = feature_extractor_config.FirstStageFeaturesStride; | |||||
var batch_norm_trainable = feature_extractor_config.BatchNormTrainable; | |||||
return new FasterRCNNResnet101FeatureExtractor(is_training, first_stage_features_stride, | |||||
batch_norm_trainable: batch_norm_trainable, | |||||
reuse_weights: reuse_weights); | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Models.ObjectDetection.Core | |||||
{ | |||||
public class AnchorGenerator | |||||
{ | |||||
} | |||||
} |
@@ -4,7 +4,28 @@ using System.Text; | |||||
namespace Tensorflow.Models.ObjectDetection | namespace Tensorflow.Models.ObjectDetection | ||||
{ | { | ||||
/// <summary> | |||||
/// Faster R-CNN Feature Extractor definition. | |||||
/// </summary> | |||||
public class FasterRCNNFeatureExtractor | public class FasterRCNNFeatureExtractor | ||||
{ | { | ||||
bool _is_training; | |||||
int _first_stage_features_stride; | |||||
bool _reuse_weights = false; | |||||
float _weight_decay = 0.0f; | |||||
bool _train_batch_norm; | |||||
public FasterRCNNFeatureExtractor(bool is_training, | |||||
int first_stage_features_stride, | |||||
bool batch_norm_trainable = false, | |||||
bool reuse_weights = false, | |||||
float weight_decay = 0.0f) | |||||
{ | |||||
_is_training = is_training; | |||||
_first_stage_features_stride = first_stage_features_stride; | |||||
_train_batch_norm = (batch_norm_trainable && is_training); | |||||
_reuse_weights = reuse_weights; | |||||
_weight_decay = weight_decay; | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,32 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using static Tensorflow.Binding; | |||||
using Tensorflow.Operations.Activation; | |||||
using Tensorflow.Models.Slim.Nets; | |||||
namespace Tensorflow.Models.ObjectDetection | |||||
{ | |||||
/// <summary> | |||||
/// Faster R-CNN Resnet 101 feature extractor implementation. | |||||
/// </summary> | |||||
public class FasterRCNNResnet101FeatureExtractor : FasterRCNNResnetV1FeatureExtractor | |||||
{ | |||||
public FasterRCNNResnet101FeatureExtractor(bool is_training, | |||||
int first_stage_features_stride, | |||||
bool batch_norm_trainable = false, | |||||
bool reuse_weights = false, | |||||
float weight_decay = 0.0f, | |||||
IActivation activation_fn = null) : base("resnet_v1_101", | |||||
ResNetV1.resnet_v1_101, | |||||
is_training, | |||||
first_stage_features_stride, | |||||
batch_norm_trainable: batch_norm_trainable, | |||||
reuse_weights: reuse_weights, | |||||
weight_decay: weight_decay, | |||||
activation_fn: activation_fn) | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,28 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using static Tensorflow.Binding; | |||||
using Tensorflow.Operations.Activation; | |||||
namespace Tensorflow.Models.ObjectDetection | |||||
{ | |||||
public class FasterRCNNResnetV1FeatureExtractor : FasterRCNNFeatureExtractor | |||||
{ | |||||
public FasterRCNNResnetV1FeatureExtractor(string architecture, | |||||
Action resnet_model, | |||||
bool is_training, | |||||
int first_stage_features_stride, | |||||
bool batch_norm_trainable = false, | |||||
bool reuse_weights = false, | |||||
float weight_decay = 0.0f, | |||||
IActivation activation_fn = null) : base(is_training, | |||||
first_stage_features_stride, | |||||
batch_norm_trainable: batch_norm_trainable, | |||||
reuse_weights: reuse_weights, | |||||
weight_decay: weight_decay) | |||||
{ | |||||
if (activation_fn == null) | |||||
activation_fn = tf.nn.relu(); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,14 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Models.Slim.Nets | |||||
{ | |||||
public class ResNetV1 | |||||
{ | |||||
public static void resnet_v1_101() | |||||
{ | |||||
throw new NotImplementedException(""); | |||||
} | |||||
} | |||||
} |
@@ -73,7 +73,6 @@ namespace TensorFlowNET.Examples.ImageProcessing.ObjectDetection | |||||
public void PrepareData() | public void PrepareData() | ||||
{ | { | ||||
throw new NotImplementedException(); | |||||
} | } | ||||
public void Train(Session sess) | public void Train(Session sess) | ||||