Browse Source

FasterRCNNFeatureExtractor

tags/v0.12
Oceania2018 6 years ago
parent
commit
eaef3aa9e6
10 changed files with 86 additions and 25 deletions
  1. +10
    -0
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Tensors/shape_utils.cs
  3. +6
    -1
      src/TensorFlowNET.Models/ObjectDetection/Builders/DatasetBuilder.cs
  4. +21
    -6
      src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs
  5. +5
    -4
      src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs
  6. +0
    -1
      src/TensorFlowNET.Models/ObjectDetection/Core/ResizeToRangeArgs.cs
  7. +19
    -5
      src/TensorFlowNET.Models/ObjectDetection/Inputs.cs
  8. +10
    -0
      src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNFeatureExtractor.cs
  9. +4
    -2
      src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNInitArgs.cs
  10. +10
    -5
      src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs

+ 10
- 0
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -334,5 +334,15 @@ namespace Tensorflow
return true;
return false;
}

public static Func<Tin1, Tout> partial<Tin1, Tout>(Func<Tin1, Tout> func, Tin1 args)
{
Func<Tin1, Tout> newfunc = (args1) =>
{
return func(args1);
};

return newfunc;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Tensors/shape_utils.cs View File

@@ -8,7 +8,7 @@ namespace Tensorflow
{
public class shape_utils
{
public static Tensor static_or_dynamic_map_fn(Func<Tensor, Tensor> fn, Tensor elems, TF_DataType dtype = TF_DataType.DtInvalid,
public static Tensor static_or_dynamic_map_fn(Func<Tensor, Tensor> fn, Tensor elems, TF_DataType[] dtypes = null,
int parallel_iterations = 32, bool back_prop = true)
{
var outputs = tf.unstack(elems).Select(arg => fn(arg)).ToArray();


+ 6
- 1
src/TensorFlowNET.Models/ObjectDetection/Builders/DatasetBuilder.cs View File

@@ -8,7 +8,7 @@ namespace Tensorflow.Models.ObjectDetection
{
public class DatasetBuilder
{
public static DatasetV1Adapter build(InputReader input_reader_config,
public DatasetV1Adapter build(InputReader input_reader_config,
int batch_size = 0,
Action transform_input_data_fn = null)
{
@@ -21,5 +21,10 @@ namespace Tensorflow.Models.ObjectDetection

throw new NotImplementedException("");
}

public Dictionary<string, Tensor> process_fn(Tensor value)
{
throw new NotImplementedException("");
}
}
}

+ 21
- 6
src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs View File

@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;
using Tensorflow.Models.ObjectDetection.Core;
using Tensorflow.Models.ObjectDetection.Protos;
using static Tensorflow.Models.ObjectDetection.Protos.ImageResizer;

@@ -18,33 +20,46 @@ namespace Tensorflow.Models.ObjectDetection
/// </summary>
/// <param name="image_resizer_config"></param>
/// <returns></returns>
public Action build(ImageResizer image_resizer_config)
public Func<ResizeToRangeArgs, Tensor[]> build(ImageResizer image_resizer_config)
{
var image_resizer_oneof = image_resizer_config.ImageResizerOneofCase;
if (image_resizer_oneof == ImageResizerOneofOneofCase.KeepAspectRatioResizer)
{
var keep_aspect_ratio_config = image_resizer_config.KeepAspectRatioResizer;
if (keep_aspect_ratio_config.MinDimension > keep_aspect_ratio_config.MaxDimension)
throw new ValueError("min_dimension > max_dimension");
var method = _tf_resize_method(keep_aspect_ratio_config.ResizeMethod);
var per_channel_pad_value = new[] { 0, 0, 0 };
if (keep_aspect_ratio_config.PerChannelPadValue.Count > 0)
throw new NotImplementedException("");
// per_channel_pad_value = new[] { keep_aspect_ratio_config.PerChannelPadValue. };
return () =>

var args = new ResizeToRangeArgs
{
min_dimension = keep_aspect_ratio_config.MinDimension,
max_dimension = keep_aspect_ratio_config.MaxDimension,
method = method,
pad_to_max_dimension = keep_aspect_ratio_config.PadToMaxDimension,
per_channel_pad_value = per_channel_pad_value
};

Func<ResizeToRangeArgs, Tensor[]> func = (input) =>
{
args.image = input.image;
return Preprocessor.resize_to_range(args);
};

return func;
}
else
{
throw new NotImplementedException("");
}

return null;
}

private ResizeType _tf_resize_method(ResizeType resize_method)
private ResizeMethod _tf_resize_method(ResizeType resize_method)
{
return resize_method;
return (ResizeMethod)(int)resize_method;
}
}
}

+ 5
- 4
src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs View File

@@ -8,11 +8,12 @@ namespace Tensorflow.Models.ObjectDetection
{
public class ModelBuilder
{
ImageResizerBuilder image_resizer_builder;
ImageResizerBuilder _image_resizer_builder;
FasterRCNNFeatureExtractor _feature_extractor;

public ModelBuilder()
{
image_resizer_builder = new ImageResizerBuilder();
_image_resizer_builder = new ImageResizerBuilder();
}

/// <summary>
@@ -43,7 +44,7 @@ namespace Tensorflow.Models.ObjectDetection
private FasterRCNNMetaArch _build_faster_rcnn_model(FasterRcnn frcnn_config, bool is_training, bool add_summaries)
{
var num_classes = frcnn_config.NumClasses;
var image_resizer_fn = image_resizer_builder.build(frcnn_config.ImageResizer);
var image_resizer_fn = _image_resizer_builder.build(frcnn_config.ImageResizer);

var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate;
var number_of_stages = frcnn_config.NumberOfStages;
@@ -53,7 +54,7 @@ namespace Tensorflow.Models.ObjectDetection
is_training = is_training,
num_classes = num_classes,
image_resizer_fn = image_resizer_fn,
feature_extractor = () => { throw new NotImplementedException(""); },
feature_extractor = _feature_extractor,
number_of_stage = number_of_stages,
first_stage_anchor_generator = null,
first_stage_atrous_rate = first_stage_atrous_rate


+ 0
- 1
src/TensorFlowNET.Models/ObjectDetection/Core/ResizeToRangeArgs.cs View File

@@ -1,7 +1,6 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.tensorflow.image_internal;

namespace Tensorflow.Models.ObjectDetection.Core
{


+ 19
- 5
src/TensorFlowNET.Models/ObjectDetection/Inputs.cs View File

@@ -9,13 +9,12 @@ namespace Tensorflow.Models.ObjectDetection
public class Inputs
{
ModelBuilder modelBuilder;
Dictionary<string, Func<DetectionModel, bool, bool, FasterRCNNMetaArch>> INPUT_BUILDER_UTIL_MAP;
DatasetBuilder datasetBuilder;

public Inputs()
{
modelBuilder = new ModelBuilder();
INPUT_BUILDER_UTIL_MAP = new Dictionary<string, Func<DetectionModel, bool, bool, FasterRCNNMetaArch>>();
INPUT_BUILDER_UTIL_MAP["model_build"] = modelBuilder.build;
datasetBuilder = new DatasetBuilder();
}

public Func<DatasetV1Adapter> create_train_input_fn(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config)
@@ -35,12 +34,27 @@ namespace Tensorflow.Models.ObjectDetection
/// <returns></returns>
public DatasetV1Adapter train_input(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config)
{
var arch = INPUT_BUILDER_UTIL_MAP["model_build"](model_config, true, true);
var arch = modelBuilder.build(model_config, true, true);
Func<Tensor, (Tensor, Tensor)> model_preprocess_fn = arch.preprocess;

var dataset = DatasetBuilder.build(train_input_config);
Func<Dictionary<string, Tensor>, (Dictionary<string, Tensor>, Dictionary<string, Tensor>) > transform_and_pad_input_data_fn = (tensor_dict) =>
{
return (_get_features_dict(tensor_dict), _get_labels_dict(tensor_dict));
};

var dataset = datasetBuilder.build(train_input_config);

return dataset;
}

private Dictionary<string, Tensor> _get_features_dict(Dictionary<string, Tensor> input_dict)
{
throw new NotImplementedException("_get_features_dict");
}

private Dictionary<string, Tensor> _get_labels_dict(Dictionary<string, Tensor> input_dict)
{
throw new NotImplementedException("_get_labels_dict");
}
}
}

+ 10
- 0
src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNFeatureExtractor.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Models.ObjectDetection
{
public class FasterRCNNFeatureExtractor
{
}
}

+ 4
- 2
src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNInitArgs.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Models.ObjectDetection.Core;

namespace Tensorflow.Models.ObjectDetection
{
@@ -8,11 +9,12 @@ namespace Tensorflow.Models.ObjectDetection
{
public bool is_training { get; set; }
public int num_classes { get; set; }
public Action image_resizer_fn { get; set; }
public Action feature_extractor { get; set; }
public Func<ResizeToRangeArgs, Tensor[]> image_resizer_fn { get; set; }
public FasterRCNNFeatureExtractor feature_extractor { get; set; }
public int number_of_stage { get; set; }
public object first_stage_anchor_generator { get; set; }
public object first_stage_target_assigner { get; set; }
public int first_stage_atrous_rate { get; set; }
public int parallel_iterations { get; set; } = 16;
}
}

+ 10
- 5
src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs View File

@@ -23,12 +23,17 @@ namespace Tensorflow.Models.ObjectDetection
{
tf_with(tf.name_scope("Preprocessor"), delegate
{
/*var outputs = shape_utils.static_or_dynamic_map_fn(
_image_resizer_fn,
var outputs = shape_utils.static_or_dynamic_map_fn(
(inputs1) =>
{
return _args.image_resizer_fn(new Core.ResizeToRangeArgs
{
image = inputs1
})[0];
},
elems: inputs,
dtype: new[] { tf.float32, tf.int32 },
parallel_iterations: _parallel_iterations);*/

dtypes: new[] { tf.float32, tf.int32 },
parallel_iterations: _args.parallel_iterations);
});

throw new NotImplementedException("");


Loading…
Cancel
Save