diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index def78327..c3079e77 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -334,5 +334,15 @@ namespace Tensorflow return true; return false; } + + public static Func partial(Func func, Tin1 args) + { + Func newfunc = (args1) => + { + return func(args1); + }; + + return newfunc; + } } } diff --git a/src/TensorFlowNET.Core/Tensors/shape_utils.cs b/src/TensorFlowNET.Core/Tensors/shape_utils.cs index 859d931b..0974dc5b 100644 --- a/src/TensorFlowNET.Core/Tensors/shape_utils.cs +++ b/src/TensorFlowNET.Core/Tensors/shape_utils.cs @@ -8,7 +8,7 @@ namespace Tensorflow { public class shape_utils { - public static Tensor static_or_dynamic_map_fn(Func fn, Tensor elems, TF_DataType dtype = TF_DataType.DtInvalid, + public static Tensor static_or_dynamic_map_fn(Func fn, Tensor elems, TF_DataType[] dtypes = null, int parallel_iterations = 32, bool back_prop = true) { var outputs = tf.unstack(elems).Select(arg => fn(arg)).ToArray(); diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/DatasetBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/DatasetBuilder.cs index ded41f0e..3c47bf51 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/Builders/DatasetBuilder.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/DatasetBuilder.cs @@ -8,7 +8,7 @@ namespace Tensorflow.Models.ObjectDetection { public class DatasetBuilder { - public static DatasetV1Adapter build(InputReader input_reader_config, + public DatasetV1Adapter build(InputReader input_reader_config, int batch_size = 0, Action transform_input_data_fn = null) { @@ -21,5 +21,10 @@ namespace Tensorflow.Models.ObjectDetection throw new NotImplementedException(""); } + + public Dictionary process_fn(Tensor value) + { + throw new NotImplementedException(""); + } } } diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs index 81c169b3..5c52bd87 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs @@ -1,6 +1,8 @@ using System; using System.Collections.Generic; using System.Text; +using static Tensorflow.Binding; +using Tensorflow.Models.ObjectDetection.Core; using Tensorflow.Models.ObjectDetection.Protos; using static Tensorflow.Models.ObjectDetection.Protos.ImageResizer; @@ -18,33 +20,46 @@ namespace Tensorflow.Models.ObjectDetection /// /// /// - public Action build(ImageResizer image_resizer_config) + public Func build(ImageResizer image_resizer_config) { var image_resizer_oneof = image_resizer_config.ImageResizerOneofCase; if (image_resizer_oneof == ImageResizerOneofOneofCase.KeepAspectRatioResizer) { var keep_aspect_ratio_config = image_resizer_config.KeepAspectRatioResizer; + if (keep_aspect_ratio_config.MinDimension > keep_aspect_ratio_config.MaxDimension) + throw new ValueError("min_dimension > max_dimension"); var method = _tf_resize_method(keep_aspect_ratio_config.ResizeMethod); var per_channel_pad_value = new[] { 0, 0, 0 }; if (keep_aspect_ratio_config.PerChannelPadValue.Count > 0) throw new NotImplementedException(""); // per_channel_pad_value = new[] { keep_aspect_ratio_config.PerChannelPadValue. }; - return () => + + var args = new ResizeToRangeArgs { + min_dimension = keep_aspect_ratio_config.MinDimension, + max_dimension = keep_aspect_ratio_config.MaxDimension, + method = method, + pad_to_max_dimension = keep_aspect_ratio_config.PadToMaxDimension, + per_channel_pad_value = per_channel_pad_value + }; + Func func = (input) => + { + args.image = input.image; + return Preprocessor.resize_to_range(args); }; + + return func; } else { throw new NotImplementedException(""); } - - return null; } - private ResizeType _tf_resize_method(ResizeType resize_method) + private ResizeMethod _tf_resize_method(ResizeType resize_method) { - return resize_method; + return (ResizeMethod)(int)resize_method; } } } diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs index 0ff80561..1f678098 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs @@ -8,11 +8,12 @@ namespace Tensorflow.Models.ObjectDetection { public class ModelBuilder { - ImageResizerBuilder image_resizer_builder; + ImageResizerBuilder _image_resizer_builder; + FasterRCNNFeatureExtractor _feature_extractor; public ModelBuilder() { - image_resizer_builder = new ImageResizerBuilder(); + _image_resizer_builder = new ImageResizerBuilder(); } /// @@ -43,7 +44,7 @@ namespace Tensorflow.Models.ObjectDetection private FasterRCNNMetaArch _build_faster_rcnn_model(FasterRcnn frcnn_config, bool is_training, bool add_summaries) { var num_classes = frcnn_config.NumClasses; - var image_resizer_fn = image_resizer_builder.build(frcnn_config.ImageResizer); + var image_resizer_fn = _image_resizer_builder.build(frcnn_config.ImageResizer); var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate; var number_of_stages = frcnn_config.NumberOfStages; @@ -53,7 +54,7 @@ namespace Tensorflow.Models.ObjectDetection is_training = is_training, num_classes = num_classes, image_resizer_fn = image_resizer_fn, - feature_extractor = () => { throw new NotImplementedException(""); }, + feature_extractor = _feature_extractor, number_of_stage = number_of_stages, first_stage_anchor_generator = null, first_stage_atrous_rate = first_stage_atrous_rate diff --git a/src/TensorFlowNET.Models/ObjectDetection/Core/ResizeToRangeArgs.cs b/src/TensorFlowNET.Models/ObjectDetection/Core/ResizeToRangeArgs.cs index 2fe799a6..1a3c8eb5 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/Core/ResizeToRangeArgs.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/Core/ResizeToRangeArgs.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Text; -using static Tensorflow.tensorflow.image_internal; namespace Tensorflow.Models.ObjectDetection.Core { diff --git a/src/TensorFlowNET.Models/ObjectDetection/Inputs.cs b/src/TensorFlowNET.Models/ObjectDetection/Inputs.cs index 0388b78a..34845786 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/Inputs.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/Inputs.cs @@ -9,13 +9,12 @@ namespace Tensorflow.Models.ObjectDetection public class Inputs { ModelBuilder modelBuilder; - Dictionary> INPUT_BUILDER_UTIL_MAP; + DatasetBuilder datasetBuilder; public Inputs() { modelBuilder = new ModelBuilder(); - INPUT_BUILDER_UTIL_MAP = new Dictionary>(); - INPUT_BUILDER_UTIL_MAP["model_build"] = modelBuilder.build; + datasetBuilder = new DatasetBuilder(); } public Func create_train_input_fn(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config) @@ -35,12 +34,27 @@ namespace Tensorflow.Models.ObjectDetection /// public DatasetV1Adapter train_input(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config) { - var arch = INPUT_BUILDER_UTIL_MAP["model_build"](model_config, true, true); + var arch = modelBuilder.build(model_config, true, true); Func model_preprocess_fn = arch.preprocess; - var dataset = DatasetBuilder.build(train_input_config); + Func, (Dictionary, Dictionary) > transform_and_pad_input_data_fn = (tensor_dict) => + { + return (_get_features_dict(tensor_dict), _get_labels_dict(tensor_dict)); + }; + + var dataset = datasetBuilder.build(train_input_config); return dataset; } + + private Dictionary _get_features_dict(Dictionary input_dict) + { + throw new NotImplementedException("_get_features_dict"); + } + + private Dictionary _get_labels_dict(Dictionary input_dict) + { + throw new NotImplementedException("_get_labels_dict"); + } } } diff --git a/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNFeatureExtractor.cs b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNFeatureExtractor.cs new file mode 100644 index 00000000..bc1f0d46 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNFeatureExtractor.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Models.ObjectDetection +{ + public class FasterRCNNFeatureExtractor + { + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNInitArgs.cs b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNInitArgs.cs index 991ffff4..e5e92161 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNInitArgs.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNInitArgs.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Models.ObjectDetection.Core; namespace Tensorflow.Models.ObjectDetection { @@ -8,11 +9,12 @@ namespace Tensorflow.Models.ObjectDetection { public bool is_training { get; set; } public int num_classes { get; set; } - public Action image_resizer_fn { get; set; } - public Action feature_extractor { get; set; } + public Func image_resizer_fn { get; set; } + public FasterRCNNFeatureExtractor feature_extractor { get; set; } public int number_of_stage { get; set; } public object first_stage_anchor_generator { get; set; } public object first_stage_target_assigner { get; set; } public int first_stage_atrous_rate { get; set; } + public int parallel_iterations { get; set; } = 16; } } diff --git a/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs index beb18198..956960b0 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs @@ -23,12 +23,17 @@ namespace Tensorflow.Models.ObjectDetection { tf_with(tf.name_scope("Preprocessor"), delegate { - /*var outputs = shape_utils.static_or_dynamic_map_fn( - _image_resizer_fn, + var outputs = shape_utils.static_or_dynamic_map_fn( + (inputs1) => + { + return _args.image_resizer_fn(new Core.ResizeToRangeArgs + { + image = inputs1 + })[0]; + }, elems: inputs, - dtype: new[] { tf.float32, tf.int32 }, - parallel_iterations: _parallel_iterations);*/ - + dtypes: new[] { tf.float32, tf.int32 }, + parallel_iterations: _args.parallel_iterations); }); throw new NotImplementedException("");