Browse Source

FasterRCNNFeatureExtractor

tags/v0.12
Oceania2018 6 years ago
parent
commit
4c2090c470
9 changed files with 167 additions and 2 deletions
  1. +15
    -0
      src/TensorFlowNET.Models/ObjectDetection/AnchorGenerators/GridAnchorGenerator.cs
  2. +27
    -0
      src/TensorFlowNET.Models/ObjectDetection/Builders/AnchorGeneratorBuilder.cs
  3. +20
    -1
      src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs
  4. +10
    -0
      src/TensorFlowNET.Models/ObjectDetection/Core/AnchorGenerator.cs
  5. +21
    -0
      src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNFeatureExtractor.cs
  6. +32
    -0
      src/TensorFlowNET.Models/ObjectDetection/Models/FasterRCNNResnet101FeatureExtractor.cs
  7. +28
    -0
      src/TensorFlowNET.Models/ObjectDetection/Models/FasterRCNNResnetV1FeatureExtractor.cs
  8. +14
    -0
      src/TensorFlowNET.Models/Slim/Nets/ResNetV1.cs
  9. +0
    -1
      test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs

+ 15
- 0
src/TensorFlowNET.Models/ObjectDetection/AnchorGenerators/GridAnchorGenerator.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Models.ObjectDetection
{
public class GridAnchorGenerator : Core.AnchorGenerator
{
public GridAnchorGenerator(float[] scales = null)
{
if (scales == null)
scales = new[] { 0.5f, 1.0f, 2.0f };
}
}
}

+ 27
- 0
src/TensorFlowNET.Models/ObjectDetection/Builders/AnchorGeneratorBuilder.cs View File

@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Models.ObjectDetection.Protos;
using static Tensorflow.Models.ObjectDetection.Protos.AnchorGenerator;

namespace Tensorflow.Models.ObjectDetection
{
public class AnchorGeneratorBuilder
{
public AnchorGeneratorBuilder()
{

}

public GridAnchorGenerator build(AnchorGenerator anchor_generator_config)
{
if(anchor_generator_config.AnchorGeneratorOneofCase == AnchorGeneratorOneofOneofCase.GridAnchorGenerator)
{
var grid_anchor_generator_config = anchor_generator_config.GridAnchorGenerator;
return new GridAnchorGenerator(scales: grid_anchor_generator_config.Scales.Select(x => float.Parse(x.ToString())).ToArray());
}
throw new NotImplementedException("");
}
}
}

+ 20
- 1
src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs View File

@@ -10,6 +10,7 @@ namespace Tensorflow.Models.ObjectDetection
{ {
ImageResizerBuilder _image_resizer_builder; ImageResizerBuilder _image_resizer_builder;
FasterRCNNFeatureExtractor _feature_extractor; FasterRCNNFeatureExtractor _feature_extractor;
AnchorGeneratorBuilder anchor_generator_builder;


public ModelBuilder() public ModelBuilder()
{ {
@@ -46,8 +47,12 @@ namespace Tensorflow.Models.ObjectDetection
var num_classes = frcnn_config.NumClasses; var num_classes = frcnn_config.NumClasses;
var image_resizer_fn = _image_resizer_builder.build(frcnn_config.ImageResizer); var image_resizer_fn = _image_resizer_builder.build(frcnn_config.ImageResizer);


var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate;
var feature_extractor = _build_faster_rcnn_feature_extractor(frcnn_config.FeatureExtractor, is_training,
inplace_batchnorm_update: frcnn_config.InplaceBatchnormUpdate);

var number_of_stages = frcnn_config.NumberOfStages; var number_of_stages = frcnn_config.NumberOfStages;
var first_stage_anchor_generator = anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator);
var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate;


return new FasterRCNNMetaArch(new FasterRCNNInitArgs return new FasterRCNNMetaArch(new FasterRCNNInitArgs
{ {
@@ -65,5 +70,19 @@ namespace Tensorflow.Models.ObjectDetection
{ {
throw new NotImplementedException(""); throw new NotImplementedException("");
} }

private FasterRCNNFeatureExtractor _build_faster_rcnn_feature_extractor(FasterRcnnFeatureExtractor feature_extractor_config,
bool is_training, bool reuse_weights = false, bool inplace_batchnorm_update = false)
{
if (inplace_batchnorm_update)
throw new ValueError("inplace batchnorm updates not supported.");
var feature_type = feature_extractor_config.Type;
var first_stage_features_stride = feature_extractor_config.FirstStageFeaturesStride;
var batch_norm_trainable = feature_extractor_config.BatchNormTrainable;

return new FasterRCNNResnet101FeatureExtractor(is_training, first_stage_features_stride,
batch_norm_trainable: batch_norm_trainable,
reuse_weights: reuse_weights);
}
} }
} }

+ 10
- 0
src/TensorFlowNET.Models/ObjectDetection/Core/AnchorGenerator.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Models.ObjectDetection.Core
{
public class AnchorGenerator
{
}
}

+ 21
- 0
src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNFeatureExtractor.cs View File

@@ -4,7 +4,28 @@ using System.Text;


namespace Tensorflow.Models.ObjectDetection namespace Tensorflow.Models.ObjectDetection
{ {
/// <summary>
/// Faster R-CNN Feature Extractor definition.
/// </summary>
public class FasterRCNNFeatureExtractor public class FasterRCNNFeatureExtractor
{ {
bool _is_training;
int _first_stage_features_stride;
bool _reuse_weights = false;
float _weight_decay = 0.0f;
bool _train_batch_norm;

public FasterRCNNFeatureExtractor(bool is_training,
int first_stage_features_stride,
bool batch_norm_trainable = false,
bool reuse_weights = false,
float weight_decay = 0.0f)
{
_is_training = is_training;
_first_stage_features_stride = first_stage_features_stride;
_train_batch_norm = (batch_norm_trainable && is_training);
_reuse_weights = reuse_weights;
_weight_decay = weight_decay;
}
} }
} }

+ 32
- 0
src/TensorFlowNET.Models/ObjectDetection/Models/FasterRCNNResnet101FeatureExtractor.cs View File

@@ -0,0 +1,32 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;
using Tensorflow.Operations.Activation;
using Tensorflow.Models.Slim.Nets;

namespace Tensorflow.Models.ObjectDetection
{
/// <summary>
/// Faster R-CNN Resnet 101 feature extractor implementation.
/// </summary>
public class FasterRCNNResnet101FeatureExtractor : FasterRCNNResnetV1FeatureExtractor
{
public FasterRCNNResnet101FeatureExtractor(bool is_training,
int first_stage_features_stride,
bool batch_norm_trainable = false,
bool reuse_weights = false,
float weight_decay = 0.0f,
IActivation activation_fn = null) : base("resnet_v1_101",
ResNetV1.resnet_v1_101,
is_training,
first_stage_features_stride,
batch_norm_trainable: batch_norm_trainable,
reuse_weights: reuse_weights,
weight_decay: weight_decay,
activation_fn: activation_fn)
{

}
}
}

+ 28
- 0
src/TensorFlowNET.Models/ObjectDetection/Models/FasterRCNNResnetV1FeatureExtractor.cs View File

@@ -0,0 +1,28 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;
using Tensorflow.Operations.Activation;

namespace Tensorflow.Models.ObjectDetection
{
public class FasterRCNNResnetV1FeatureExtractor : FasterRCNNFeatureExtractor
{
public FasterRCNNResnetV1FeatureExtractor(string architecture,
Action resnet_model,
bool is_training,
int first_stage_features_stride,
bool batch_norm_trainable = false,
bool reuse_weights = false,
float weight_decay = 0.0f,
IActivation activation_fn = null) : base(is_training,
first_stage_features_stride,
batch_norm_trainable: batch_norm_trainable,
reuse_weights: reuse_weights,
weight_decay: weight_decay)
{
if (activation_fn == null)
activation_fn = tf.nn.relu();
}
}
}

+ 14
- 0
src/TensorFlowNET.Models/Slim/Nets/ResNetV1.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Models.Slim.Nets
{
public class ResNetV1
{
public static void resnet_v1_101()
{
throw new NotImplementedException("");
}
}
}

+ 0
- 1
test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs View File

@@ -73,7 +73,6 @@ namespace TensorFlowNET.Examples.ImageProcessing.ObjectDetection


public void PrepareData() public void PrepareData()
{ {
throw new NotImplementedException();
} }


public void Train(Session sess) public void Train(Session sess)


Loading…
Cancel
Save