diff --git a/README.md b/README.md index 4d87a8da..4593d2cf 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,7 @@ Import TF.NET and Keras API in your project. using static Tensorflow.Binding; using static Tensorflow.KerasApi; using Tensorflow; -using NumSharp; +using Tensorflow.NumPy; ``` Linear Regression in `Eager` mode: @@ -162,10 +162,9 @@ Linear Regression in `Eager` mode: #r "nuget: TensorFlow.Net" #r "nuget: TensorFlow.Keras" #r "nuget: SciSharp.TensorFlow.Redist" -#r "nuget: NumSharp" -open NumSharp open Tensorflow +open Tensorflow.NumPy open type Tensorflow.Binding open type Tensorflow.KerasApi diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 1dc8a035..f89977b0 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -203,13 +203,6 @@ namespace Tensorflow yield return values[i]; } - public static T New() where T : ITensorFlowObject, new() - { - var instance = new T(); - instance.__init__(); - return instance; - } - [DebuggerStepThrough] public static void tf_with(ITensorFlowObject py, Action action) { diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs index 5682f328..04dd7a9c 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs @@ -81,7 +81,7 @@ namespace Tensorflow.Eager if (ops.gradientFunctions[op_name] == null) return new Tensor[op_inputs.Length]; - var gradients = ops.gradientFunctions[op_name](new EagerOperation + var op = new EagerOperation { Name = op_name, NumInputs = op_inputs.Length, @@ -90,9 +90,9 @@ namespace Tensorflow.Eager Outputs = op_outputs, SkipInputIndices = unneeded_gradients, Attrs = attrs - }, output_grads); + }; - return gradients; + return ops.gradientFunctions[op_name](op, output_grads); }; bool CouldForwardprop() diff --git a/src/TensorFlowNET.Core/Interfaces/ITensorFlowObject.cs b/src/TensorFlowNET.Core/Interfaces/ITensorFlowObject.cs index 1fc24813..74d01558 100644 --- a/src/TensorFlowNET.Core/Interfaces/ITensorFlowObject.cs +++ b/src/TensorFlowNET.Core/Interfaces/ITensorFlowObject.cs @@ -20,15 +20,8 @@ namespace Tensorflow { public interface ITensorFlowObject : IDisposable { - /// - /// Called when the instance is created. - /// - void __init__(); - void __enter__(); void __exit__(); - - void __del__(); } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs new file mode 100644 index 00000000..13fd98b4 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs @@ -0,0 +1,16 @@ +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class LayerNormalizationArgs : LayerArgs + { + public Axis Axis { get; set; } = -1; + public float Epsilon { get; set; } = 1e-3f; + public bool Center { get; set; } = true; + public bool Scale { get; set; } = true; + public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer; + public IInitializer GammaInitializer { get; set; } = tf.ones_initializer; + public IRegularizer BetaRegularizer { get; set; } + public IRegularizer GammaRegularizer { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index c0a71239..1ad13896 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -89,8 +89,8 @@ tf.net 0.6x.x aligns with TensorFlow v2.6.x native library. - - + + diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 7c6fea08..1d9396f4 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -316,7 +316,7 @@ namespace Tensorflow.Keras.Engine var outputs = node.Layer.Apply(layer_inputs, is_training: training ?? false); foreach (var output in outputs.Where(x => x != null)) tf.Logger.Information($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.shape}"); - // Update tensor_dict for next input + // Update tensor_dict for next or later input foreach (var (x_id, y) in zip(node.Outputs.Select(x => x.Id), outputs)) tensor_dict[x_id] = new Queue(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); } diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 6ffde8ef..8bbc0cc8 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -635,6 +635,21 @@ namespace Tensorflow.Keras.Layers return layer.Apply(inputs); } + public Layer LayerNormalization(Axis? axis, + float epsilon = 1e-3f, + bool center = true, + bool scale = true, + IInitializer beta_initializer = null, + IInitializer gamma_initializer = null) + => new LayerNormalization(new LayerNormalizationArgs + { + Axis = axis ?? -1, + Epsilon = epsilon, + Center = center, + Scale = scale, + BetaInitializer = beta_initializer ?? tf.zeros_initializer + }); + /// /// Leaky version of a Rectified Linear Unit. /// diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs index 1a29badf..6fb244b2 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs @@ -218,7 +218,8 @@ namespace Tensorflow.Keras.Layers beta, mean: moving_mean, variance: moving_variance, - epsilon: epsilon, is_training: true, + epsilon: epsilon, + is_training: true, data_format: _data_format, exponential_avg_factor: exponential_avg_factor); }; diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs new file mode 100644 index 00000000..2fd56afb --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs @@ -0,0 +1,145 @@ +/***************************************************************************** + Copyright 2021 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + public class LayerNormalization : Layer + { + LayerNormalizationArgs args; + + float epsilon => args.Epsilon; + bool center => args.Center; + bool scale => args.Scale; + bool _fused; + int[] axis; + string _data_format; + Shape kernel_size; + IInitializer beta_initializer => args.BetaInitializer; + IInitializer gamma_initializer => args.GammaInitializer; + IRegularizer gamma_regularizer => args.GammaRegularizer; + IVariableV1 gamma; + IVariableV1 beta; + IVariableV1 moving_mean; + IVariableV1 moving_variance; + + public LayerNormalization(LayerNormalizationArgs args) : base(args) + { + this.args = args; + axis = args.Axis.axis; + } + + protected override void build(Tensors inputs) + { + Shape input_shape = inputs.shape; + var ndims = input_shape.ndim; + foreach (var (idx, x) in enumerate(axis)) + if (x < 0) + axis[idx] = ndims + x; + + var axis_to_dim = new Dictionary(); + foreach (var x in axis) + axis_to_dim[x] = (int)input_shape[x]; + + inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim); + var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; + var param_shape = inputSpec.AllAxisDim; + + if (scale) + gamma = add_weight("gamma", + param_shape, + dtype: param_dtype, + initializer: gamma_initializer, + trainable: true); + + if (center) + beta = add_weight("beta", + param_shape, + dtype: param_dtype, + initializer: beta_initializer, + trainable: true); + + _fused = _fused_can_be_used(ndims); + + built = true; + } + + bool _fused_can_be_used(int ndims) + { + var can_use_fused = false; + if (axis.Last() == ndims - 1 && axis.Last() - axis[0] == len(axis) - 1) + can_use_fused = true; + if (epsilon < 1.001e-5 || DType != tf.float32) + can_use_fused = false; + return can_use_fused; + } + + public override Shape ComputeOutputShape(Shape input_shape) + { + return input_shape; + } + + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + { + Tensors outputs = null; + var inputs_dtype = inputs.dtype.as_base_dtype(); + var input_shape = inputs.shape; + var ndims = len(input_shape); + var broadcast_shape = range(ndims).Select(x => 1).ToArray(); + foreach (var dim in axis) + broadcast_shape[dim] = input_shape.as_int_list()[dim]; + + if (_fused) + { + var tensor_shape = tf.shape(inputs); + var pre_dim = tf.constant(1); + var in_dim = tf.constant(1); + foreach (var dim in range(ndims)) + { + var dim_tensor = tensor_shape[dim]; + if (dim < axis[0]) + pre_dim = pre_dim * dim_tensor; + else + in_dim = in_dim * dim_tensor; + } + inputs = tf.reshape(inputs, new object[] { 1, pre_dim, in_dim, 1 }); + + var scale = tf.ones(new Shape((int)pre_dim), dtype: DType); + var offset = tf.zeros(new Shape((int)pre_dim), dtype: DType); + + /*outputs = tf.nn.fused_batch_norm( + inputs, + scale: scale, + offset: offset, + epsilon: epsilon, + data_format: "NCHW");*/ + } + else + { + + } + + return outputs; + } + } +} diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index 1ae9e6f6..7b0ef5ba 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -60,7 +60,7 @@ Keras is an API designed for human beings, not machines. Keras follows best prac - + diff --git a/test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs b/test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs index 1577f6f1..17268c30 100644 --- a/test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs +++ b/test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs @@ -26,7 +26,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { - var default_graph = tf.peak_default_graph(); + var default_graph = tf.get_default_graph(); var sess_graph = sess.graph; Assert.IsNotNull(default_graph); Assert.IsNotNull(sess_graph); @@ -49,7 +49,7 @@ namespace TensorFlowNET.UnitTest //tf.Session created an other graph using (var sess = tf.Session()) { - var default_graph = tf.peak_default_graph(); + var default_graph = tf.get_default_graph(); var sess_graph = sess.graph; Assert.IsNotNull(default_graph); Assert.IsNotNull(sess_graph); @@ -159,7 +159,8 @@ namespace TensorFlowNET.UnitTest var math = a1 + a2; for (int i = 0; i < 100; i++) { - using (var sess = tf.Session()) + var graph = tf.get_default_graph(); + using (var sess = tf.Session(graph)) { var result = sess.run(math); Assert.AreEqual(result[0], 5f); @@ -171,14 +172,14 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void SessionRun_InsideSession() { - MultiThreadedUnitTestExecuter.Run(1, Core); + MultiThreadedUnitTestExecuter.Run(8, Core); //the core method void Core(int tid) { using (var sess = tf.Session()) { - Assert.IsNotNull(tf.peak_default_graph()); + Assert.IsNotNull(tf.get_default_graph()); //graph is created automatically to perform create these operations var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); @@ -200,7 +201,7 @@ namespace TensorFlowNET.UnitTest { using (var sess = tf.Session()) { - Assert.IsNotNull(tf.peak_default_graph()); + Assert.IsNotNull(tf.get_default_graph()); //graph is created automatically to perform create these operations var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); diff --git a/test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj b/test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj index 6112fc3b..ab977853 100644 --- a/test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj +++ b/test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj @@ -24,7 +24,7 @@ - + diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs index 865d7520..b45b9731 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs @@ -152,5 +152,14 @@ namespace TensorFlowNET.Keras.UnitTest var output = layer.Apply(inputs); Assert.AreEqual((10, 16, 16, 3), output.shape); } + + [TestMethod] + public void LayerNormalization() + { + var inputs = tf.constant(np.arange(10).reshape((5, 2)) * 10, dtype: tf.float32); + var layer = keras.layers.LayerNormalization(axis: 1); + var output = layer.Apply(inputs); + // Assert.AreEqual((10, 16, 16, 3), output.shape); + } } } diff --git a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj index 04e7a5e7..885b5167 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj +++ b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj @@ -14,7 +14,7 @@ - + diff --git a/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj b/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj index ae809b6c..957a3c92 100644 --- a/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj +++ b/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj @@ -44,7 +44,7 @@ - + diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj index 01a0bfea..2d4c3b18 100644 --- a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj @@ -47,8 +47,8 @@ - - + +