From 6a8665f111b42531e90922509ec298cde58ea330 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 7 Feb 2021 23:15:57 -0600 Subject: [PATCH] Add Conv2DTranspose #735 --- src/TensorFlowNET.Core/Binding.Util.cs | 5 + .../Framework/smart_module.cs | 11 +- src/TensorFlowNET.Core/Graphs/c_api.graph.cs | 13 ++ .../Keras/Engine/InputSpec.cs | 2 +- .../Operations/nn_impl.py.cs | 42 +++++ .../Tensors/TensorShape.Equals.cs | 4 + src/TensorFlowNET.Keras/BackendImpl.cs | 25 +++ src/TensorFlowNET.Keras/Engine/Functional.cs | 4 +- src/TensorFlowNET.Keras/Engine/Layer.Apply.cs | 6 +- src/TensorFlowNET.Keras/Engine/Layer.cs | 2 +- .../Engine/Model.Evaluate.cs | 2 +- .../Engine/Model.Predict.cs | 2 +- src/TensorFlowNET.Keras/Engine/Model.Train.cs | 2 +- .../Engine/Model.Training.cs | 9 +- .../Layers/Activation/LeakyReLu.cs | 2 +- .../Layers/Convolution/Conv2DTranspose.cs | 150 ++++++++++++++++++ .../Layers/Convolution/Convolutional.cs | 5 +- src/TensorFlowNET.Keras/Layers/Core/Dense.cs | 2 +- .../Layers/Core/Embedding.cs | 2 +- src/TensorFlowNET.Keras/Layers/LSTM.cs | 4 +- src/TensorFlowNET.Keras/Layers/LayersApi.cs | 45 ++++++ .../Layers/Merging/Merge.cs | 2 +- .../Normalization/BatchNormalization.cs | 74 ++++++++- .../Layers/Pooling/GlobalAveragePooling2D.cs | 2 +- .../Layers/Pooling/Pooling2D.cs | 2 +- .../Layers/Regularization/Dropout.cs | 7 +- .../Layers/Rescaling/Rescaling.cs | 2 +- .../Layers/Reshaping/Flatten.cs | 2 +- .../Layers/Reshaping/Reshape.cs | 2 +- .../Layers/Reshaping/UpSampling2D.cs | 2 +- .../Layers/Reshaping/ZeroPadding2D.cs | 2 +- .../Layers/TensorFlowOpLayer.cs | 2 +- .../Tensorflow.Keras.csproj | 8 +- src/TensorFlowNET.Keras/Utils/conv_utils.cs | 30 ++++ src/TensorFlowNET.Keras/Utils/tf_utils.cs | 11 ++ 35 files changed, 447 insertions(+), 40 deletions(-) create mode 100644 src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 535bbca4..d58fbe70 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -149,6 +149,8 @@ namespace Tensorflow return ndArray.ndim == 0 ? 1 : ndArray.shape[0]; case IEnumerable enumerable: return enumerable.OfType().Count(); + case TensorShape arr: + return arr.ndim; } throw new NotImplementedException("len() not implemented for type: " + a.GetType()); } @@ -156,6 +158,9 @@ namespace Tensorflow public static float min(float a, float b) => Math.Min(a, b); + public static int max(int a, int b) + => Math.Max(a, b); + public static T[] list(IEnumerable list) => list.ToArray(); diff --git a/src/TensorFlowNET.Core/Framework/smart_module.cs b/src/TensorFlowNET.Core/Framework/smart_module.cs index 7a8654c5..d9e35a6d 100644 --- a/src/TensorFlowNET.Core/Framework/smart_module.cs +++ b/src/TensorFlowNET.Core/Framework/smart_module.cs @@ -15,6 +15,8 @@ ******************************************************************************/ using System; +using System.Linq; +using static Tensorflow.Binding; namespace Tensorflow.Framework { @@ -52,7 +54,14 @@ namespace Tensorflow.Framework { var pred_value = tensor_util.constant_value(pred); if (pred_value is null) - return pred.eval(new Session(pred.graph)); + { + var result = range(pred.op.NumOutputs).Select(x => IntPtr.Zero).ToArray(); + var evaluated = c_api.TF_TryEvaluateConstant(pred.graph, pred._as_tf_output(), result, tf.Status.Handle); + if (!evaluated || c_api.TF_GetCode(tf.Status.Handle) != TF_Code.TF_OK) + return null; + else + throw new NotImplementedException(""); + } return pred_value; } diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 280e4bb6..2f5af971 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -322,5 +322,18 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TF_UpdateEdge(IntPtr graph, TF_Output new_src, TF_Input dst, SafeStatusHandle status); + + /// + /// Attempts to evaluate `output`. This will only be possible if `output` doesn't + /// depend on any graph inputs (this function is safe to call if this isn't the + /// case though). + /// + /// + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern bool TF_TryEvaluateConstant(IntPtr graph, TF_Output output, IntPtr[] result, SafeStatusHandle status); } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs index 4993fc2a..198e8162 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs @@ -50,6 +50,6 @@ namespace Tensorflow.Keras.Engine } public override string ToString() - => $"min_ndim={min_ndim}, , axes={axes.Count}"; + => $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}"; } } diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs index 1da2c252..82fa2acb 100644 --- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs +++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs @@ -21,6 +21,31 @@ namespace Tensorflow { public class nn_impl { + public static Tensor conv2d_transpose(Tensor value = null, + IVariableV1 filter = null, + Tensor output_shape = null, + TensorShape strides = null, + string padding = "SAME", + string data_format = "NHWC", + string name = null, + TensorShape dilations = null) + { + if (dilations == null) + dilations = (1, 1, 1, 1); + return tf_with(ops.name_scope(name, "conv2d_transpose", new { value, filter, output_shape }), scope => + { + return gen_nn_ops.conv2d_backprop_input( + input_sizes: output_shape, + filter: filter.AsTensor(), + out_backprop: value, + strides: strides, + padding: padding, + data_format: data_format, + dilations: dilations, + name: name); + }); + } + /// /// Normalizes along dimension `axis` using an L2 norm. /// @@ -83,6 +108,23 @@ namespace Tensorflow }); } + public static Tensor batch_normalization(Tensor x, + Tensor mean, + Tensor variance, + Tensor offset, + Tensor scale, + float variance_epsilon = 0.001f, + string name = null) + { + return tf_with(ops.name_scope(name, "batchnorm", new { x, mean, variance, scale, offset }), scope => + { + var inv = math_ops.rsqrt(variance + variance_epsilon); + inv *= scale; + return x * math_ops.cast(inv, x.dtype) + math_ops.cast( + offset == null ? (-mean * inv) : (offset - mean * inv), x.dtype); + }); + } + /// /// Batch normalization. /// diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs index 9078dbed..d892f750 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs @@ -15,6 +15,10 @@ namespace Tensorflow else if (rank != shape1.rank) return false; return Enumerable.SequenceEqual(shape1.dims, dims); + case int[] shape2: + if (rank != shape2.Length) + return false; + return Enumerable.SequenceEqual(dims, shape2); default: return false; } diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index 3a68a61a..c82acce4 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -317,5 +317,30 @@ namespace Tensorflow.Keras return array_ops.concat(tensors, axis); } + + public Tensor conv2d_transpose(Tensor x, + IVariableV1 kernel, + Tensor output_shape, + TensorShape strides = null, + string padding = "valid", + string data_format = null, + TensorShape dilation_rate = null) + { + var force_transpose = false; + if (data_format == "channels_first" && !dilation_rate.Equals(new[] { 1, 1 })) + force_transpose = true; + // x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose) + var tf_data_format = "NHWC"; + padding = padding.ToUpper(); + strides = new TensorShape(1, strides[0], strides[1], 1); + if (dilation_rate.Equals(new[] { 1, 1 })) + x = nn_impl.conv2d_transpose(x, kernel, output_shape, strides, + padding: padding, + data_format: tf_data_format); + else + throw new NotImplementedException(""); + + return x; + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 2f177451..78038cff 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -301,9 +301,9 @@ namespace Tensorflow.Keras.Engine nodes_in_decreasing_depth.append(node); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { - return run_internal_graph(inputs, is_training); + return run_internal_graph(inputs, training.Value); } Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null) diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs index 3b896b68..b19d5307 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs @@ -10,9 +10,9 @@ namespace Tensorflow.Keras.Engine /// /// /// - /// + /// /// - public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false) + public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false) { callContext = callContext ?? new ThreadLocal() { @@ -38,7 +38,7 @@ namespace Tensorflow.Keras.Engine if (!built) MaybeBuild(inputs); - outputs = Call(inputs, state: state, is_training: is_training); + outputs = Call(inputs, state: state, training: training); // memory leak // _set_connectivity_metadata_(inputs, outputs); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index ab47b3dc..fc5d3de9 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -155,7 +155,7 @@ namespace Tensorflow.Keras.Engine /// /// /// - protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { throw new NotImplementedException(""); } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs index 7097670c..8a484c3b 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs @@ -73,7 +73,7 @@ namespace Tensorflow.Keras.Engine List<(string, Tensor)> test_step(Tensor x, Tensor y) { (x, y) = data_handler.DataAdapter.Expand1d(x, y); - var y_pred = Apply(x, is_training: false); + var y_pred = Apply(x, training: false); var loss = compiled_loss.Call(y, y_pred); compiled_metrics.update_state(y, y_pred); diff --git a/src/TensorFlowNET.Keras/Engine/Model.Predict.cs b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs index 2971cbb8..8b01d022 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Predict.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs @@ -76,7 +76,7 @@ namespace Tensorflow.Keras.Engine Tensors predict_step(Tensor data) { - return Apply(data, is_training: false); + return Apply(data, training: false); } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Keras/Engine/Model.Train.cs index 961405d5..31e89c57 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Train.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Train.cs @@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine { (x, y) = data_handler.DataAdapter.Expand1d(x, y); using var tape = tf.GradientTape(); - var y_pred = Apply(x, is_training: true); + var y_pred = Apply(x, training: true); var loss = compiled_loss.Call(y, y_pred); // For custom training steps, users can just write: diff --git a/src/TensorFlowNET.Keras/Engine/Model.Training.cs b/src/TensorFlowNET.Keras/Engine/Model.Training.cs index 6bf0eed9..23763f07 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Training.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Training.cs @@ -8,9 +8,10 @@ using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Engine { - public partial class Model + public partial class Model { - public List<(IVariableV1, NDArray)> load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null) + List<(IVariableV1, NDArray)> LoadedWeights; + public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null) { long fileId = Hdf5.OpenFile(filepath, true); @@ -25,10 +26,8 @@ namespace Tensorflow.Keras.Engine throw new NotImplementedException(""); else { - var weights = hdf5_format.load_weights_from_hdf5_group(fileId, Layers); + LoadedWeights = hdf5_format.load_weights_from_hdf5_group(fileId, Layers); Hdf5.CloseFile(fileId); - // return a reference to prevent GC collect Variable. - return weights; } } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs b/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs index 625e81d4..1fbbf4ea 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs @@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { return tf.nn.leaky_relu(inputs, alpha: alpha); } diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs new file mode 100644 index 00000000..ffd4e9b3 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs @@ -0,0 +1,150 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using static Tensorflow.Binding; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Utils; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Layers +{ + public class Conv2DTranspose : Conv2D + { + public Conv2DTranspose(Conv2DArgs args) : base(args) + { + + } + + protected override void build(Tensors inputs) + { + var input_shape = inputs.shape; + if (len(input_shape) != 4) + throw new ValueError($"Inputs should have rank 4. Received input shape: {input_shape}"); + + var channel_axis = _get_channel_axis(); + var input_dim = input_shape[-1]; + var kernel_shape = new TensorShape(kernel_size[0], kernel_size[1], filters, input_dim); + + kernel = add_weight(name: "kernel", + shape: kernel_shape, + initializer: kernel_initializer, + regularizer: kernel_regularizer, + trainable: true, + dtype: inputs.dtype); + if (use_bias) + bias = add_weight(name: "bias", + shape: filters, + initializer: bias_initializer, + trainable: true, + dtype: inputs.dtype); + built = true; + } + + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + { + var inputs_shape = array_ops.shape(inputs); + var batch_size = inputs_shape[0]; + var (h_axis, w_axis) = (1, 2); + if (data_format == "channels_first") + (h_axis, w_axis) = (2, 3); + var (height, width) = (-1, -1); + if(inputs.shape.rank > -1) + { + var dims = inputs.shape.dims; + (height, width) = (dims[h_axis], dims[w_axis]); + } + var (kernel_h, kernel_w) = kernel_size; + var (stride_h, stride_w) = strides; + + var (out_pad_h, out_pad_w) = (-1, -1); + + // Infer the dynamic output shape: + var out_height = conv_utils.deconv_output_length(height, + kernel_h, + padding: padding, + output_padding: out_pad_h, + stride: stride_h, + dilation: dilation_rate[0]); + + var out_width = conv_utils.deconv_output_length(width, + kernel_w, + padding: padding, + output_padding: out_pad_w, + stride: stride_w, + dilation: dilation_rate[1]); + + Tensor output_shape_tensor; + if (data_format == "channels_first") + output_shape_tensor = array_ops.stack(new object[] { batch_size, filters, out_height, out_width }); + else + output_shape_tensor = array_ops.stack(new object[] { batch_size, out_height, out_width, filters }); + + var outputs = keras.backend.conv2d_transpose( + inputs, + kernel, + output_shape_tensor, + strides: strides, + padding: padding, + data_format: data_format, + dilation_rate: dilation_rate); + + if (!tf.Context.executing_eagerly()) + { + var out_shape = ComputeOutputShape(inputs.shape); + outputs.set_shape(out_shape); + } + + if (use_bias) + throw new NotImplementedException(""); + + if (activation != null) + return activation(outputs); + + return outputs; + } + + public override TensorShape ComputeOutputShape(TensorShape input_shape) + { + var output_shape = input_shape.dims; + var (c_axis, h_axis, w_axis) = (3, 1, 2); + if (data_format == "channels_first") + (c_axis, h_axis, w_axis) = (1, 2, 3); + + var (kernel_h, kernel_w) = kernel_size; + var (stride_h, stride_w) = strides; + + var (out_pad_h, out_pad_w) = (-1, -1); + output_shape[c_axis] = filters; + output_shape[h_axis] = conv_utils.deconv_output_length( + output_shape[h_axis], + kernel_h, + padding: padding, + output_padding: out_pad_h, + stride: stride_h, + dilation: dilation_rate[0]); + output_shape[w_axis] = conv_utils.deconv_output_length( + output_shape[w_axis], + kernel_w, + padding: padding, + output_padding: out_pad_w, + stride: stride_w, + dilation: dilation_rate[1]); + + return new TensorShape(output_shape); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs index 7d4da4af..2139fd32 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs @@ -99,7 +99,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = false) { var outputs = _convolution_op.Apply(inputs, kernel); if (use_bias) @@ -119,5 +119,8 @@ namespace Tensorflow.Keras.Layers return outputs; } + + protected virtual int _get_channel_axis() + => data_format == "channels_first" ? -1 - rank : -1; } } diff --git a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs index 7f992c5e..a6334713 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs @@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { Tensor outputs = null; var rank = inputs.rank; diff --git a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs index 36bbd152..131d0627 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs @@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { var dtype = inputs.dtype; if (dtype != tf.int32 && dtype != tf.int64) diff --git a/src/TensorFlowNET.Keras/Layers/LSTM.cs b/src/TensorFlowNET.Keras/Layers/LSTM.cs index 3db6afba..73a2df12 100644 --- a/src/TensorFlowNET.Keras/Layers/LSTM.cs +++ b/src/TensorFlowNET.Keras/Layers/LSTM.cs @@ -26,9 +26,9 @@ namespace Tensorflow.Keras.Layers .ToArray(); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { - return base.Call(inputs, state: state, is_training: is_training); + return base.Call(inputs, state: state, training: training); } } } diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 3f8fae3d..dc9ee749 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -140,6 +140,51 @@ namespace Tensorflow.Keras.Layers Activation = GetActivationByName(activation) }); + /// + /// Transposed convolution layer (sometimes called Deconvolution). + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public Conv2DTranspose Conv2DTranspose(int filters, + TensorShape kernel_size = null, + TensorShape strides = null, + string padding = "valid", + string data_format = null, + TensorShape dilation_rate = null, + string activation = null, + bool use_bias = true, + string kernel_initializer = null, + string bias_initializer = null, + string kernel_regularizer = null, + string bias_regularizer = null, + string activity_regularizer = null) + => new Conv2DTranspose(new Conv2DArgs + { + Rank = 2, + Filters = filters, + KernelSize = kernel_size, + Strides = strides == null ? (1, 1) : strides, + Padding = padding, + DataFormat = data_format, + DilationRate = dilation_rate == null ? (1, 1) : dilation_rate, + UseBias = use_bias, + KernelInitializer = GetInitializerByName(kernel_initializer), + BiasInitializer = GetInitializerByName(bias_initializer), + Activation = GetActivationByName(activation) + }); + public Dense Dense(int units, Activation activation = null, IInitializer kernel_initializer = null, diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs index c0fa3f36..be8f574e 100644 --- a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs +++ b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs @@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers // output_shape = input_shape.dims[1^]; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { return _merge_function(inputs); } diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs index d4dbb3d7..8ad1f224 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs @@ -36,6 +36,7 @@ namespace Tensorflow.Keras.Layers bool fused; int[] axis; string _data_format; + TensorShape kernel_size; IInitializer beta_initializer => args.BetaInitializer; IInitializer gamma_initializer => args.GammaInitializer; IInitializer moving_mean_initializer => args.MovingMeanInitializer; @@ -120,10 +121,35 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool training = false) + public override TensorShape ComputeOutputShape(TensorShape input_shape) + { + return input_shape; + } + + (Tensor, Tensor) _moments(Tensors inputs, int[] reduction_axes, bool keep_dims) + { + var (mean, variance) = _calculate_mean_and_var(inputs, reduction_axes, keep_dims); + if (_support_zero_size_input()) + throw new NotImplementedException(""); + return (mean, variance); + } + + (Tensor, Tensor) _calculate_mean_and_var(Tensors inputs, int[] reduction_axes, bool keep_dims) + { + return nn_impl.moments(inputs, reduction_axes, keep_dims: keep_dims); + } + + bool _support_zero_size_input() + { + return false; + } + + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { Tensor outputs = null; - var training_tensor = tf.logical_and(training, Trainable); + var training_tensor = training == null + ? tf.placeholder(tf.@bool, TensorShape.Scalar) + : tf.logical_and(training.Value, Trainable); if (fused) { // var training = tf.convert_to_tensor(training); @@ -131,7 +157,49 @@ namespace Tensorflow.Keras.Layers return outputs; } - throw new NotImplementedException("BatchNormalization call"); + var inputs_dtype = inputs.dtype.as_base_dtype(); + var input_shape = inputs.shape; + var ndims = len(input_shape); + var reduction_axes = range(ndims).Where(x => !axis.Contains(x)).ToArray(); + + // Broadcasting only necessary for single-axis batch norm where the axis is + // not the last dimension + var broadcast_shape = range(ndims).Select(x => 1).ToArray(); + broadcast_shape[axis[0]] = input_shape.dims[axis[0]]; + + var (scale, offset) = (gamma, beta); + var training_value = tf_utils.constant_value(training_tensor); + + Tensor mean; + Tensor variance; + if (training_value.HasValue && training_value.Value == false) + { + (mean, variance) = (moving_mean.AsTensor(), moving_variance.AsTensor()); + } + else + { + var keep_dims = len(axis) > 1; + (mean, variance) = _moments(inputs, reduction_axes, keep_dims: keep_dims); + mean = tf_utils.smart_cond(training_tensor, + () => new[] { mean }, + () => new[] { ops.convert_to_tensor(moving_mean) }).FirstOrDefault(); + + variance = tf_utils.smart_cond(training_tensor, + () => new[] { variance }, + () => new[] { ops.convert_to_tensor(moving_variance) }).FirstOrDefault(); + + var (new_mean, new_variance) = (mean, variance); + } + + mean = math_ops.cast(mean, inputs.dtype); + variance = math_ops.cast(variance, inputs.dtype); + var offset_tensor = math_ops.cast(offset, inputs.dtype); + var scale_tensor = math_ops.cast(scale, inputs.dtype); + outputs = nn_impl.batch_normalization(inputs, mean, variance, + offset_tensor, scale_tensor, epsilon); + // If some components of the shape got lost due to adjustments, fix that. + outputs.set_shape(input_shape); + return outputs; } private Tensor _fused_batch_norm(Tensor inputs, Tensor training) diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs index efc8050d..b35d7832 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs @@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers { } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { if (data_format == "channels_last") return math_ops.reduce_mean(inputs, new int[] { 1, 2 }, false); diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs index 72285540..3f67e803 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs @@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers input_spec = new InputSpec(ndim: 4); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { int[] pool_shape; int[] strides; diff --git a/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs b/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs index 2eba70c7..aa3a92a4 100644 --- a/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs +++ b/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs @@ -15,9 +15,12 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { - var output = tf_utils.smart_cond(is_training, + if (training == null) + training = false; + + var output = tf_utils.smart_cond(training.Value, () => tf.nn.dropout(inputs, noise_shape: get_noise_shape(inputs), seed: args.Seed, diff --git a/src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs b/src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs index 7466685f..10609c6b 100644 --- a/src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs +++ b/src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs @@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { scale = math_ops.cast(args.Scale, args.DType); offset = math_ops.cast(args.Offset, args.DType); diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs index 66235198..f376c7d5 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs @@ -22,7 +22,7 @@ namespace Tensorflow.Keras.Layers _channels_first = args.DataFormat == "channels_first"; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { if (_channels_first) { diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs index 68bd76af..e8f7d01c 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs @@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { var shapes = new List(); shapes.Add(array_ops.shape(inputs)[0]); diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs index 3fcf6c99..8314151f 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs @@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Layers inputSpec = new InputSpec(ndim: 4); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { return keras.backend.resize_images(inputs, size[0], size[1], diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs index 7f6ff3e7..101c00c2 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs @@ -26,7 +26,7 @@ namespace Tensorflow.Keras.Layers this.input_spec = new InputSpec(ndim: 4); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { return keras.backend.spatial_2d_padding(inputs, padding: padding, diff --git a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs index 1c0470fe..17951219 100644 --- a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs @@ -32,7 +32,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { if (tf.Context.executing_eagerly()) return _defun_call(inputs); diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index 3f5ca2b9..5694e8f5 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -55,10 +55,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac - - - - True @@ -70,4 +66,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac + + + + diff --git a/src/TensorFlowNET.Keras/Utils/conv_utils.cs b/src/TensorFlowNET.Keras/Utils/conv_utils.cs index 8d799468..baedca92 100644 --- a/src/TensorFlowNET.Keras/Utils/conv_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/conv_utils.cs @@ -14,7 +14,9 @@ limitations under the License. ******************************************************************************/ +using System; using System.Linq; +using static Tensorflow.Binding; namespace Tensorflow.Keras.Utils { @@ -63,5 +65,33 @@ namespace Tensorflow.Keras.Utils return ImageDataFormat.channels_last.ToString(); return value.ToLower(); } + + public static int deconv_output_length(int input_length, + int filter_size, + string padding, + int output_padding = -1, + int stride = 0, + int dilation = 1) + { + // Get the dilated kernel size + filter_size = filter_size + (filter_size - 1) * (dilation - 1); + + // Infer length if output padding is None, else compute the exact length + int length = -1; + if (output_padding == -1) + { + if (padding == "valid") + length = input_length * stride + max(filter_size - stride, 0); + else if (padding == "full") + length = input_length * stride - (stride + filter_size - 2); + else if (padding == "same") + length = input_length * stride; + } + else + { + throw new NotImplementedException(""); + } + return length; + } } } diff --git a/src/TensorFlowNET.Keras/Utils/tf_utils.cs b/src/TensorFlowNET.Keras/Utils/tf_utils.cs index de542270..b144ec9f 100644 --- a/src/TensorFlowNET.Keras/Utils/tf_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/tf_utils.cs @@ -37,6 +37,17 @@ namespace Tensorflow.Keras.Utils return true; } + public static Tensor[] smart_cond(IVariableV1 pred, + Func true_fn = null, + Func false_fn = null, + string name = null) + { + return control_flow_ops.cond(pred.AsTensor(), + true_fn: true_fn, + false_fn: false_fn, + name: name); + } + public static Tensor[] smart_cond(Tensor pred, Func true_fn = null, Func false_fn = null,