@@ -334,5 +334,15 @@ namespace Tensorflow | |||
return true; | |||
return false; | |||
} | |||
public static Func<Tin1, Tout> partial<Tin1, Tout>(Func<Tin1, Tout> func, Tin1 args) | |||
{ | |||
Func<Tin1, Tout> newfunc = (args1) => | |||
{ | |||
return func(args1); | |||
}; | |||
return newfunc; | |||
} | |||
} | |||
} |
@@ -8,7 +8,7 @@ namespace Tensorflow | |||
{ | |||
public class shape_utils | |||
{ | |||
public static Tensor static_or_dynamic_map_fn(Func<Tensor, Tensor> fn, Tensor elems, TF_DataType dtype = TF_DataType.DtInvalid, | |||
public static Tensor static_or_dynamic_map_fn(Func<Tensor, Tensor> fn, Tensor elems, TF_DataType[] dtypes = null, | |||
int parallel_iterations = 32, bool back_prop = true) | |||
{ | |||
var outputs = tf.unstack(elems).Select(arg => fn(arg)).ToArray(); | |||
@@ -8,7 +8,7 @@ namespace Tensorflow.Models.ObjectDetection | |||
{ | |||
public class DatasetBuilder | |||
{ | |||
public static DatasetV1Adapter build(InputReader input_reader_config, | |||
public DatasetV1Adapter build(InputReader input_reader_config, | |||
int batch_size = 0, | |||
Action transform_input_data_fn = null) | |||
{ | |||
@@ -21,5 +21,10 @@ namespace Tensorflow.Models.ObjectDetection | |||
throw new NotImplementedException(""); | |||
} | |||
public Dictionary<string, Tensor> process_fn(Tensor value) | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
} | |||
} |
@@ -1,6 +1,8 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
using Tensorflow.Models.ObjectDetection.Core; | |||
using Tensorflow.Models.ObjectDetection.Protos; | |||
using static Tensorflow.Models.ObjectDetection.Protos.ImageResizer; | |||
@@ -18,33 +20,46 @@ namespace Tensorflow.Models.ObjectDetection | |||
/// </summary> | |||
/// <param name="image_resizer_config"></param> | |||
/// <returns></returns> | |||
public Action build(ImageResizer image_resizer_config) | |||
public Func<ResizeToRangeArgs, Tensor[]> build(ImageResizer image_resizer_config) | |||
{ | |||
var image_resizer_oneof = image_resizer_config.ImageResizerOneofCase; | |||
if (image_resizer_oneof == ImageResizerOneofOneofCase.KeepAspectRatioResizer) | |||
{ | |||
var keep_aspect_ratio_config = image_resizer_config.KeepAspectRatioResizer; | |||
if (keep_aspect_ratio_config.MinDimension > keep_aspect_ratio_config.MaxDimension) | |||
throw new ValueError("min_dimension > max_dimension"); | |||
var method = _tf_resize_method(keep_aspect_ratio_config.ResizeMethod); | |||
var per_channel_pad_value = new[] { 0, 0, 0 }; | |||
if (keep_aspect_ratio_config.PerChannelPadValue.Count > 0) | |||
throw new NotImplementedException(""); | |||
// per_channel_pad_value = new[] { keep_aspect_ratio_config.PerChannelPadValue. }; | |||
return () => | |||
var args = new ResizeToRangeArgs | |||
{ | |||
min_dimension = keep_aspect_ratio_config.MinDimension, | |||
max_dimension = keep_aspect_ratio_config.MaxDimension, | |||
method = method, | |||
pad_to_max_dimension = keep_aspect_ratio_config.PadToMaxDimension, | |||
per_channel_pad_value = per_channel_pad_value | |||
}; | |||
Func<ResizeToRangeArgs, Tensor[]> func = (input) => | |||
{ | |||
args.image = input.image; | |||
return Preprocessor.resize_to_range(args); | |||
}; | |||
return func; | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
return null; | |||
} | |||
private ResizeType _tf_resize_method(ResizeType resize_method) | |||
private ResizeMethod _tf_resize_method(ResizeType resize_method) | |||
{ | |||
return resize_method; | |||
return (ResizeMethod)(int)resize_method; | |||
} | |||
} | |||
} |
@@ -8,11 +8,12 @@ namespace Tensorflow.Models.ObjectDetection | |||
{ | |||
public class ModelBuilder | |||
{ | |||
ImageResizerBuilder image_resizer_builder; | |||
ImageResizerBuilder _image_resizer_builder; | |||
FasterRCNNFeatureExtractor _feature_extractor; | |||
public ModelBuilder() | |||
{ | |||
image_resizer_builder = new ImageResizerBuilder(); | |||
_image_resizer_builder = new ImageResizerBuilder(); | |||
} | |||
/// <summary> | |||
@@ -43,7 +44,7 @@ namespace Tensorflow.Models.ObjectDetection | |||
private FasterRCNNMetaArch _build_faster_rcnn_model(FasterRcnn frcnn_config, bool is_training, bool add_summaries) | |||
{ | |||
var num_classes = frcnn_config.NumClasses; | |||
var image_resizer_fn = image_resizer_builder.build(frcnn_config.ImageResizer); | |||
var image_resizer_fn = _image_resizer_builder.build(frcnn_config.ImageResizer); | |||
var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate; | |||
var number_of_stages = frcnn_config.NumberOfStages; | |||
@@ -53,7 +54,7 @@ namespace Tensorflow.Models.ObjectDetection | |||
is_training = is_training, | |||
num_classes = num_classes, | |||
image_resizer_fn = image_resizer_fn, | |||
feature_extractor = () => { throw new NotImplementedException(""); }, | |||
feature_extractor = _feature_extractor, | |||
number_of_stage = number_of_stages, | |||
first_stage_anchor_generator = null, | |||
first_stage_atrous_rate = first_stage_atrous_rate | |||
@@ -1,7 +1,6 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using static Tensorflow.tensorflow.image_internal; | |||
namespace Tensorflow.Models.ObjectDetection.Core | |||
{ | |||
@@ -9,13 +9,12 @@ namespace Tensorflow.Models.ObjectDetection | |||
public class Inputs | |||
{ | |||
ModelBuilder modelBuilder; | |||
Dictionary<string, Func<DetectionModel, bool, bool, FasterRCNNMetaArch>> INPUT_BUILDER_UTIL_MAP; | |||
DatasetBuilder datasetBuilder; | |||
public Inputs() | |||
{ | |||
modelBuilder = new ModelBuilder(); | |||
INPUT_BUILDER_UTIL_MAP = new Dictionary<string, Func<DetectionModel, bool, bool, FasterRCNNMetaArch>>(); | |||
INPUT_BUILDER_UTIL_MAP["model_build"] = modelBuilder.build; | |||
datasetBuilder = new DatasetBuilder(); | |||
} | |||
public Func<DatasetV1Adapter> create_train_input_fn(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config) | |||
@@ -35,12 +34,27 @@ namespace Tensorflow.Models.ObjectDetection | |||
/// <returns></returns> | |||
public DatasetV1Adapter train_input(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config) | |||
{ | |||
var arch = INPUT_BUILDER_UTIL_MAP["model_build"](model_config, true, true); | |||
var arch = modelBuilder.build(model_config, true, true); | |||
Func<Tensor, (Tensor, Tensor)> model_preprocess_fn = arch.preprocess; | |||
var dataset = DatasetBuilder.build(train_input_config); | |||
Func<Dictionary<string, Tensor>, (Dictionary<string, Tensor>, Dictionary<string, Tensor>) > transform_and_pad_input_data_fn = (tensor_dict) => | |||
{ | |||
return (_get_features_dict(tensor_dict), _get_labels_dict(tensor_dict)); | |||
}; | |||
var dataset = datasetBuilder.build(train_input_config); | |||
return dataset; | |||
} | |||
private Dictionary<string, Tensor> _get_features_dict(Dictionary<string, Tensor> input_dict) | |||
{ | |||
throw new NotImplementedException("_get_features_dict"); | |||
} | |||
private Dictionary<string, Tensor> _get_labels_dict(Dictionary<string, Tensor> input_dict) | |||
{ | |||
throw new NotImplementedException("_get_labels_dict"); | |||
} | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Models.ObjectDetection | |||
{ | |||
public class FasterRCNNFeatureExtractor | |||
{ | |||
} | |||
} |
@@ -1,6 +1,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Models.ObjectDetection.Core; | |||
namespace Tensorflow.Models.ObjectDetection | |||
{ | |||
@@ -8,11 +9,12 @@ namespace Tensorflow.Models.ObjectDetection | |||
{ | |||
public bool is_training { get; set; } | |||
public int num_classes { get; set; } | |||
public Action image_resizer_fn { get; set; } | |||
public Action feature_extractor { get; set; } | |||
public Func<ResizeToRangeArgs, Tensor[]> image_resizer_fn { get; set; } | |||
public FasterRCNNFeatureExtractor feature_extractor { get; set; } | |||
public int number_of_stage { get; set; } | |||
public object first_stage_anchor_generator { get; set; } | |||
public object first_stage_target_assigner { get; set; } | |||
public int first_stage_atrous_rate { get; set; } | |||
public int parallel_iterations { get; set; } = 16; | |||
} | |||
} |
@@ -23,12 +23,17 @@ namespace Tensorflow.Models.ObjectDetection | |||
{ | |||
tf_with(tf.name_scope("Preprocessor"), delegate | |||
{ | |||
/*var outputs = shape_utils.static_or_dynamic_map_fn( | |||
_image_resizer_fn, | |||
var outputs = shape_utils.static_or_dynamic_map_fn( | |||
(inputs1) => | |||
{ | |||
return _args.image_resizer_fn(new Core.ResizeToRangeArgs | |||
{ | |||
image = inputs1 | |||
})[0]; | |||
}, | |||
elems: inputs, | |||
dtype: new[] { tf.float32, tf.int32 }, | |||
parallel_iterations: _parallel_iterations);*/ | |||
dtypes: new[] { tf.float32, tf.int32 }, | |||
parallel_iterations: _args.parallel_iterations); | |||
}); | |||
throw new NotImplementedException(""); | |||