From 954713f46f2b82aac385b2a6886d9480892ee3f2 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 8 Jun 2019 10:46:12 -0500 Subject: [PATCH] add Conv2dParams --- src/TensorFlowNET.Core/Keras/Layers/Conv.cs | 6 ++- src/TensorFlowNET.Core/Keras/Layers/Layer.cs | 8 +-- src/TensorFlowNET.Core/Layers/Layer.cs | 7 --- .../Operations/NnOps/Conv2dParams.cs | 53 +++++++++++++++++++ .../Operations/NnOps/_NonAtrousConvolution.cs | 16 +++--- .../Operations/NnOps/gen_nn_ops.cs | 48 ++++++++++------- src/TensorFlowNET.Core/Operations/nn_ops.cs | 1 + .../Variables/RefVariable.cs | 2 +- .../Variables/VariableScope.cs | 2 +- src/TensorFlowNET.Core/ops.py.cs | 2 +- 10 files changed, 102 insertions(+), 43 deletions(-) create mode 100644 src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs index efa79e00..aeb37128 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs @@ -56,7 +56,9 @@ namespace Tensorflow.Keras.Layers protected override void build(TensorShape input_shape) { int channel_axis = data_format == "channels_first" ? 1 : -1; - int input_dim = input_shape.Dimensions[input_shape.NDim - 1]; + int input_dim = channel_axis < 0 ? 
+ input_shape.Dimensions[input_shape.NDim + channel_axis] : + input_shape.Dimensions[channel_axis]; var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters }; kernel = add_weight(name: "kernel", shape: kernel_shape, @@ -102,7 +104,7 @@ namespace Tensorflow.Keras.Layers } else { - outputs = nn_ops.bias_add(outputs, bias._AsTensor(), data_format: "NHWC"); + outputs = nn_ops.bias_add(outputs, bias, data_format: "NHWC"); } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs index db089959..a0a151ef 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs @@ -206,12 +206,14 @@ namespace Tensorflow.Keras.Layers _updates.AddRange(updates_op); } + // Determine layer name (non-unique). protected virtual void _init_set_name(string name, bool zero_based = true) { + var base_name = name; + _name = name; if (name == null) - _name = base_layer_utils.unique_layer_name(generic_utils.to_snake_case(this.GetType().Name), zero_based: zero_based); - else - _name = name; + (_name, base_name) = _make_unique_name(); + _base_name = base_name; } protected virtual (string, string) _make_unique_name() diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index e8844d46..ff0fd28e 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -67,13 +67,6 @@ namespace Tensorflow.Layers return outputs; } - protected override void _init_set_name(string name, bool zero_based = true) - { - // Determine layer name (non-unique). 
- base._init_set_name(name, zero_based: zero_based); - _base_name = this.name; - } - protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list) { foreach(var name in collection_list) diff --git a/src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs b/src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs new file mode 100644 index 00000000..de480692 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations +{ + public class Conv2dParams + { + public string Name { get; set; } + + /// + /// An optional `string` from: `"NHWC", "NCHW"`. Defaults to `"NHWC"`. + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, height, width, channels]. + /// + public string DataFormat { get; set; } = "NHWC"; + + /// + /// Must be one of the following types: `half`, `bfloat16`, `float32`, `float64`. + /// A 4-D tensor. The dimension order is interpreted according to the value of `data_format`. + /// + public Tensor Input { get; set; } + + /// + /// A 4-D tensor of shape `[filter_height, filter_width, in_channels, out_channels]`. + /// + public Tensor Filter { get; set; } + + /// + /// The stride of the sliding window for each + /// dimension of `input`. The dimension order is determined by the value of + /// `data_format`, see below for details. + /// + public int[] Strides { get; set; } + + /// + /// A `string` from: `"SAME", "VALID", "EXPLICIT"`. 
+ /// + public string Padding { get; set; } + + public int[] ExplicitPaddings { get; set; } = new int[0]; + + public bool UseCudnnOnGpu { get; set; } = true; + + public int[] Dilations { get; set; } = new [] { 1, 1, 1, 1 }; + + public Conv2dParams() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs b/src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs index e6a3958a..c742884a 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs @@ -11,7 +11,7 @@ namespace Tensorflow.Operations public string name; public int[] strides; public string data_format; - private Func conv_op; + private Func conv_op; public _NonAtrousConvolution(TensorShape input_shape, TensorShape filter_shape, @@ -55,14 +55,14 @@ namespace Tensorflow.Operations public Tensor __call__(Tensor inp, RefVariable filter) { - return conv_op(new + return conv_op(new Conv2dParams { - input = inp, - filter, - strides, - padding, - data_format, - name + Input = inp, + Filter = filter, + Strides = strides, + Padding = padding, + DataFormat = data_format, + Name = name }); } } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index 9dfc882e..7e985a01 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -10,28 +10,36 @@ namespace Tensorflow.Operations { public static OpDefLibrary _op_def_lib = new OpDefLibrary(); - public static Tensor conv2d(object parameters) + /// + /// Computes a 2-D convolution given 4-D `input` and `filter` tensors. + /// + /// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` + /// and a filter / kernel tensor of shape + /// `[filter_height, filter_width, in_channels, out_channels]`, this op + /// performs the following: + /// + /// 1. 
Flattens the filter to a 2-D matrix with shape + /// `[filter_height * filter_width * in_channels, output_channels]`. + /// 2. Extracts image patches from the input tensor to form a *virtual* + /// tensor of shape `[batch, out_height, out_width, + /// filter_height * filter_width * in_channels]`. + /// 3. For each patch, right-multiplies the filter matrix and the image patch + /// vector. + /// + /// + /// + public static Tensor conv2d(Conv2dParams parameters) { - var args = Python.ConvertToDict(parameters); - - var input = args["input"]; - var filter = args["filter"]; - var strides = args["strides"]; - var padding = args["padding"]; - var name = args["name"]; - var data_format = args.ContainsKey("data_format") ? args["data_format"] : "NHWC"; - var use_cudnn_on_gpu = args.ContainsKey("use_cudnn_on_gpu") ? args["use_cudnn_on_gpu"] : true; - var dilations = args.ContainsKey("dilations") ? args["dilations"] : new int[] { 1, 1, 1, 1 }; - - var _op = _op_def_lib._apply_op_helper("Conv2D", name: name?.ToString(), args: new + var _op = _op_def_lib._apply_op_helper("Conv2D", name: parameters.Name, args: new { - input, - filter, - strides, - padding, - use_cudnn_on_gpu, - data_format, - dilations + input = parameters.Input, + filter = parameters.Filter, + strides = parameters.Strides, + padding = parameters.Padding, + use_cudnn_on_gpu = parameters.UseCudnnOnGpu, + explicit_paddings = parameters.ExplicitPaddings, + data_format = parameters.DataFormat, + dilations = parameters.Dilations }); return _op.outputs[0]; diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs index b778bfae..69d5f166 100644 --- a/src/TensorFlowNET.Core/Operations/nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -37,6 +37,7 @@ namespace Tensorflow { return Python.with(ops.name_scope(name, "BiasAdd", new { value, bias }), scope => { + name = scope; value = ops.convert_to_tensor(value, name: "input"); var bias_tensor = 
ops.convert_to_tensor(bias, dtype: value.dtype, name: "bias"); return gen_nn_ops.bias_add(value, bias_tensor, data_format: data_format, name: name); diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 3d586f4b..7b1de63c 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -188,7 +188,7 @@ namespace Tensorflow public Tensor _as_graph_element() => _variable; - public Tensor _TensorConversionFunction(bool as_ref = false) + public Tensor _TensorConversionFunction(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) { if (as_ref) return _ref(); diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs index fe6f973e..60cd7777 100644 --- a/src/TensorFlowNET.Core/Variables/VariableScope.cs +++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs @@ -40,7 +40,7 @@ namespace Tensorflow VariableSynchronization synchronization = VariableSynchronization.Auto, VariableAggregation aggregation= VariableAggregation.None) { - string full_name = !string.IsNullOrEmpty(this._name) ? this._name + "/" + name : name; + string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name; return with(ops.name_scope(null), scope => { if (dtype == TF_DataType.DtInvalid) diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 20dbf668..48c7909c 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -473,7 +473,7 @@ namespace Tensorflow case Tensor[] tensors: return array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name); case RefVariable varVal: - return varVal._TensorConversionFunction(as_ref: as_ref); + return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref); case ResourceVariable varVal: return null; case object[] objects: