diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 6b144534..ffb84d4f 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -124,10 +124,10 @@ namespace Tensorflow
             => gen_nn_ops.relu(features, name);
 
         public Tensor[] fused_batch_norm(Tensor x,
-            IVariableV1 scale,
-            IVariableV1 offset,
-            IVariableV1 mean = null,
-            IVariableV1 variance = null,
+            Tensor scale,
+            Tensor offset,
+            Tensor mean = null,
+            Tensor variance = null,
             float epsilon = 0.001f,
             string data_format = "NHWC",
             bool is_training = true,
diff --git a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
index a64713ae..be7bf670 100644
--- a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
+++ b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
@@ -19,7 +19,6 @@ using System.Collections.Generic;
 using System.Linq;
 using System.Reflection;
 using Tensorflow.Gradients;
-using static Tensorflow.Binding;
 
 namespace Tensorflow
 {
@@ -49,14 +48,48 @@ namespace Tensorflow
                 RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name,
                     (oper, out_grads) =>
                     {
-                        tf.Logger.Debug($"Caculate Gradient: {oper.name} {m.Name}");
-                        var results = g.InvokeMember(m.Name,
-                            BindingFlags.InvokeMethod,
-                            null,
-                            null,
-                            args: new object[] { oper, out_grads }) as Tensor[];
-                        foreach (var result in results.Where(x => x != null))
-                            tf.Logger.Debug($"Gradient: {result.name} {result.shape}");
+                        // tf.Logger.Debug($"Calculate Gradient: {oper.name} {m.Name}");
+
+                        var results = m.Name switch
+                        {
+                            "_AddGrad" => math_grad._AddGrad(oper, out_grads),
+                            "_AddV2Grad" => math_grad._AddV2Grad(oper, out_grads),
+                            "_BiasAddGrad" => nn_grad._BiasAddGrad(oper, out_grads),
+                            "_CastGrad" => math_grad._CastGrad(oper, out_grads),
+                            "_ConcatGradV2" => array_grad._ConcatGradV2(oper, out_grads),
+                            "_Conv2DGrad" => nn_grad._Conv2DGrad(oper, out_grads),
+                            "_ExpandDimsGrad" => array_grad._ExpandDimsGrad(oper, out_grads),
+                            "_ExpGrad" => math_grad._ExpGrad(oper, out_grads),
+                            "_FusedBatchNormV3Grad" => nn_grad._FusedBatchNormV3Grad(oper, out_grads),
+                            "_IdGrad" => math_grad._IdGrad(oper, out_grads),
+                            "_LeakyReluGrad" => nn_grad._LeakyReluGrad(oper, out_grads),
+                            "_Log1pGrad" => math_grad._Log1pGrad(oper, out_grads),
+                            "_MaximumGrad" => math_grad._MaximumGrad(oper, out_grads),
+                            "_MeanGrad" => math_grad._MeanGrad(oper, out_grads),
+                            "_MinimumGrad" => math_grad._MinimumGrad(oper, out_grads),
+                            "_MulGrad" => math_grad._MulGrad(oper, out_grads),
+                            "_NegGrad" => math_grad._NegGrad(oper, out_grads),
+                            "_PadGrad" => array_grad._PadGrad(oper, out_grads),
+                            "_PowGrad" => math_grad._PowGrad(oper, out_grads),
+                            "_RealDivGrad" => math_grad._RealDivGrad(oper, out_grads),
+                            "_ReadGrad" => resource_variable_grad._ReadGrad(oper, out_grads),
+                            "_ReshapeGrad" => array_grad._ReshapeGrad(oper, out_grads),
+                            "_ResizeNearestNeighborGrad" => image_grad._ResizeNearestNeighborGrad(oper, out_grads),
+                            "_SelectGrad" => math_grad._SelectGrad(oper, out_grads),
+                            "_SigmoidGrad" => math_grad._SigmoidGrad(oper, out_grads),
+                            "_SumGrad" => math_grad._SumGrad(oper, out_grads),
+                            "_SubGrad" => math_grad._SubGrad(oper, out_grads),
+                            "_StridedSliceGrad" => array_grad._StridedSliceGrad(oper, out_grads),
+                            _ => g.InvokeMember(m.Name,
+                                BindingFlags.InvokeMethod,
+                                null,
+                                null,
+                                args: new object[] { oper, out_grads }) as Tensor[]
+                        };
+
+                        // foreach (var result in results.Where(x => x != null))
+                        //     tf.Logger.Debug($"Gradient: {result.name} {result.shape}");
+
                         return results;
                     }
                 );
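The switch above takes per-call reflection off the gradient hot path: `Type.InvokeMember` re-resolves the method by name and boxes arguments on every backward step, while a switch expression compiles to a string comparison followed by a direct call, and the reflection branch survives only as the fallback for unlisted gradient functions. A minimal standalone sketch of the pattern (hypothetical `AddGrad`/`MulGrad` stand-ins, not the TF.NET gradient methods):

```csharp
using System;
using System.Diagnostics;
using System.Reflection;

public static class DispatchSketch
{
    public static int AddGrad(int x) => x + 1;
    public static int MulGrad(int x) => x * 2;

    public static void Main()
    {
        const int iterations = 1_000_000;
        var type = typeof(DispatchSketch);

        var sw = Stopwatch.StartNew();
        for (int i = 0; i < iterations; i++)
        {
            // Reflection dispatch: name lookup and argument boxing on every call.
            type.InvokeMember("AddGrad",
                BindingFlags.InvokeMethod | BindingFlags.Public | BindingFlags.Static,
                binder: null, target: null, args: new object[] { i });
        }
        Console.WriteLine($"reflection: {sw.ElapsedMilliseconds} ms");

        sw.Restart();
        var name = "AddGrad";
        for (int i = 0; i < iterations; i++)
        {
            // Switch dispatch: one string comparison, then a direct call.
            _ = name switch
            {
                "AddGrad" => AddGrad(i),
                "MulGrad" => MulGrad(i),
                _ => throw new NotSupportedException(name)
            };
        }
        Console.WriteLine($"switch: {sw.ElapsedMilliseconds} ms");
    }
}
```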
diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs
index 6e3a4c76..2aa327b5 100644
--- a/src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs
+++ b/src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs
@@ -17,6 +17,10 @@ namespace Tensorflow.NumPy
             float val => GetAtIndex<float>(0) == val,
             double val => GetAtIndex<double>(0) == val,
             string val => StringData(0) == val,
+            int[] val => ToArray<int>().SequenceEqual(val),
+            long[] val => ToArray<long>().SequenceEqual(val),
+            float[] val => ToArray<float>().SequenceEqual(val),
+            double[] val => ToArray<double>().SequenceEqual(val),
             NDArray val => Equals(this, val),
             _ => base.Equals(obj)
         };
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
index 5b09810e..31ac8650 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
@@ -191,10 +191,10 @@ namespace Tensorflow.Operations
         }
 
         public static Tensors fused_batch_norm_v3(Tensor x,
-            IVariableV1 scale,
-            IVariableV1 offset,
-            IVariableV1 mean,
-            IVariableV1 variance,
+            Tensor scale,
+            Tensor offset,
+            Tensor mean,
+            Tensor variance,
             float epsilon = 0.0001f,
             float exponential_avg_factor = 1.0f,
             string data_format = "NHWC",
diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
index 7e2ed36f..d24e81ef 100644
--- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
@@ -150,20 +150,18 @@ namespace Tensorflow
         ///
        ///
         public static Tensor[] fused_batch_norm(Tensor x,
-            IVariableV1 scale,
-            IVariableV1 offset,
-            IVariableV1 mean,
-            IVariableV1 variance,
+            Tensor scale,
+            Tensor offset,
+            Tensor mean = null,
+            Tensor variance = null,
             float epsilon = 0.001f,
             string data_format = "NHWC",
             bool is_training = true,
             string name = null,
             float exponential_avg_factor = 1.0f)
         {
-            /*if (mean == null)
-                mean = constant_op.constant(new float[0]);
-            if (variance == null)
-                variance = constant_op.constant(new float[0]);*/
+            mean = mean ?? constant_op.constant(new float[0]);
+            variance = variance ?? constant_op.constant(new float[0]);
 
             var min_epsilon = 1.001e-5f;
             epsilon = epsilon > min_epsilon ? epsilon : min_epsilon;
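With the defaults above, `mean` and `variance` may now be omitted when training: `nn_impl` substitutes empty tensors, which the fused kernel reads as "compute batch statistics yourself" and then returns alongside the normalized output. A hedged usage sketch against the signature in this diff (shapes and values chosen purely for illustration):

```csharp
using Tensorflow;
using static Tensorflow.Binding;

// NHWC input with three channels; scale/offset are per-channel.
var x = tf.ones(new Shape(2, 4, 4, 3), dtype: tf.float32);
var scale = tf.ones(new Shape(3), dtype: tf.float32);
var offset = tf.zeros(new Shape(3), dtype: tf.float32);

// mean/variance left null: the empty-tensor defaults kick in and the
// kernel computes statistics over the current batch.
var results = tf.nn.fused_batch_norm(x, scale, offset, is_training: true);
var y = results[0];           // normalized output
var batch_mean = results[1];  // per-channel batch mean
var batch_var = results[2];   // per-channel batch variance
```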
diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs
index 2c3ea4fd..b9cafda1 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensors.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs
@@ -29,6 +29,8 @@ namespace Tensorflow
         {
             get
             {
+                if (Length == 1)
+                    return items[0][index];
                 return items[index];
             }
diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs
index 6fb244b2..da8e8c03 100644
--- a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs
+++ b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs
@@ -214,10 +214,10 @@ namespace Tensorflow.Keras.Layers
             {
                 return tf.nn.fused_batch_norm(
                     inputs,
-                    gamma,
-                    beta,
-                    mean: moving_mean,
-                    variance: moving_variance,
+                    gamma.AsTensor(),
+                    beta.AsTensor(),
+                    mean: moving_mean.AsTensor(),
+                    variance: moving_variance.AsTensor(),
                     epsilon: epsilon,
                     is_training: true,
                     data_format: _data_format,
@@ -228,10 +228,10 @@ namespace Tensorflow.Keras.Layers
             {
                 return tf.nn.fused_batch_norm(
                     inputs,
-                    gamma,
-                    beta,
-                    mean: moving_mean,
-                    variance: moving_variance,
+                    gamma.AsTensor(),
+                    beta.AsTensor(),
+                    mean: moving_mean.AsTensor(),
+                    variance: moving_variance.AsTensor(),
                     epsilon: epsilon,
                     is_training: false,
                     data_format: _data_format);
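The `Tensors.cs` change above makes indexing a single-element `Tensors` drill into the wrapped tensor rather than return the wrapper's only item; that is what lets test code like `output[0]` later in this diff slice the first row of a layer's output. A small sketch of the two behaviors (illustrative values only, assuming the `params Tensor[]` constructor):

```csharp
using Tensorflow;
using static Tensorflow.Binding;

// A Tensors holding exactly one 2x2 tensor: the indexer now forwards to
// the underlying tensor, so this yields the first row [1, 2].
var single = new Tensors(tf.constant(new[,] { { 1f, 2f }, { 3f, 4f } }));
var row = single[0];

// A Tensors holding two tensors: the indexer keeps its old meaning and
// returns the first tensor in the list.
var pair = new Tensors(tf.constant(1f), tf.constant(2f));
var first = pair[0];
```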
diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs
index 2fd56afb..51c6423c 100644
--- a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs
+++ b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs
@@ -101,7 +101,7 @@ namespace Tensorflow.Keras.Layers
         protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
         {
-            Tensors outputs = null;
+            Tensor outputs = null;
             var inputs_dtype = inputs.dtype.as_base_dtype();
             var input_shape = inputs.shape;
             var ndims = len(input_shape);
@@ -109,6 +109,13 @@ namespace Tensorflow.Keras.Layers
             foreach (var dim in axis)
                 broadcast_shape[dim] = input_shape.as_int_list()[dim];
 
+            Func<IVariableV1, Tensor> _broadcast = v =>
+            {
+                if (v.shape.ndim != ndims && !axis.SequenceEqual(new int[] { ndims - 1 }))
+                    return tf.reshape(v.AsTensor(), broadcast_shape);
+                return v.AsTensor();
+            };
+
             if (_fused)
             {
                 var tensor_shape = tf.shape(inputs);
@@ -127,18 +134,28 @@ namespace Tensorflow.Keras.Layers
                 var scale = tf.ones(new Shape((int)pre_dim), dtype: DType);
                 var offset = tf.zeros(new Shape((int)pre_dim), dtype: DType);
 
-                /*outputs = tf.nn.fused_batch_norm(
+                outputs = tf.nn.fused_batch_norm(
                     inputs,
                     scale: scale,
                     offset: offset,
                     epsilon: epsilon,
-                    data_format: "NCHW");*/
+                    data_format: "NCHW")[0];
+
+                outputs = tf.reshape(outputs, tensor_shape);
+
+                (scale, offset) = (_broadcast(gamma), _broadcast(beta));
+
+                outputs = outputs * tf.cast(scale, outputs.dtype);
+                outputs = outputs + tf.cast(offset, outputs.dtype);
             }
             else
             {
             }
 
+            // If some components of the shape got lost due to adjustments, fix that.
+            outputs.shape = input_shape;
+
             return outputs;
         }
     }
diff --git a/test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs b/test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs
index 17268c30..91dc84b2 100644
--- a/test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs
+++ b/test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs
@@ -152,7 +152,6 @@ namespace TensorFlowNET.UnitTest
             //the core method
             void Core(int tid)
             {
-                Assert.IsNull(tf.peak_default_graph());
                 //graph is created automatically to perform create these operations
                 var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
                 var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 });
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
index b45b9731..3aeabef5 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
@@ -5,6 +5,7 @@ using Tensorflow;
 using Tensorflow.Keras;
 using static Tensorflow.Binding;
 using static Tensorflow.KerasApi;
+using System.Linq;
 
 namespace TensorFlowNET.Keras.UnitTest
 {
@@ -86,7 +87,7 @@ namespace TensorFlowNET.Keras.UnitTest
             var emb = keras.layers.Embedding(256, 12, input_length: 4);
             var input_array = np.arange(12).reshape((3, 4)).astype(np.float32);
             var output = emb.Apply(input_array);
-            Assert.AreEqual(new Shape(3, 4, 12), output.shape);
+            Assert.AreEqual((3, 4, 12), output.shape);
         }
 
         ///
@@ -159,7 +160,8 @@ namespace TensorFlowNET.Keras.UnitTest
             var inputs = tf.constant(np.arange(10).reshape((5, 2)) * 10, dtype: tf.float32);
             var layer = keras.layers.LayerNormalization(axis: 1);
             var output = layer.Apply(inputs);
-            // Assert.AreEqual((10, 16, 16, 3), output.shape);
+            Assert.AreEqual((5, 2), output.shape);
+            Assert.IsTrue(output[0].numpy().Equals(new[] { -0.99998f, 0.99998f }));
         }
     }
 }
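The expected values in the new assertion follow directly from the layer-norm arithmetic: the first input row is [0, 10], so the mean is 5 and the variance 25, and the layer's default epsilon of 1e-3 pulls the normalized values just inside ±1. A short check of that arithmetic:

```csharp
using System;

// First input row is [0, 10] (np.arange(10).reshape((5, 2)) * 10).
double mean = (0 + 10) / 2.0;                                              // 5
double variance = (Math.Pow(0 - mean, 2) + Math.Pow(10 - mean, 2)) / 2.0; // 25
double epsilon = 1e-3;  // LayerNormalization default

double hi = (10 - mean) / Math.Sqrt(variance + epsilon);
Console.WriteLine($"{-hi:F5}, {hi:F5}");  // -0.99998, 0.99998
```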