From 3aa2738570250ab92a370e7e7c60f334c2468256 Mon Sep 17 00:00:00 2001
From: Haiping Chen
Date: Mon, 13 Mar 2023 21:48:05 -0500
Subject: [PATCH] Add layers.Normalization.

---
 .../Normalization/NormalizationArgs.cs        |  15 ++
 src/TensorFlowNET.Core/Keras/Layers/ILayer.cs |   1 +
 .../Keras/Layers/ILayersApi.cs                |   1 +
 src/TensorFlowNET.Core/NumPy/ShapeHelper.cs   |   3 +
 .../Operations/NnOps/RNNCell.cs               |   5 +
 src/TensorFlowNET.Core/tensorflow.cs          |  25 ++-
 src/TensorFlowNET.Keras/Engine/Layer.cs       |   5 +
 src/TensorFlowNET.Keras/KerasInterface.cs     |   4 -
 src/TensorFlowNET.Keras/Layers/LayersApi.cs   |   9 +
 .../Layers/Normalization/Normalization.cs     | 173 ++++++++++++++++++
 .../Preprocessing/PreprocessingLayer.cs       |  81 ++++++++
 .../EagerModeTestBase.cs                      |   1 +
 .../Layers/LayersTest.cs                      |  49 +++++
 13 files changed, 367 insertions(+), 5 deletions(-)
 create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/NormalizationArgs.cs
 create mode 100644 src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs

diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/NormalizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/NormalizationArgs.cs
new file mode 100644
index 00000000..30c90145
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/NormalizationArgs.cs
@@ -0,0 +1,15 @@
+using Newtonsoft.Json;
+
+namespace Tensorflow.Keras.ArgsDefinition;
+
+public class NormalizationArgs : PreprocessingLayerArgs
+{
+    [JsonProperty("axis")]
+    public Axis? Axis { get; set; }
+    [JsonProperty("mean")]
+    public float? Mean { get; set; }
+    [JsonProperty("variance")]
+    public float? Variance { get; set; }
+
+    public bool Invert { get; set; } = false;
+}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
index d2dfe8c5..2b864f90 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
@@ -23,5 +23,6 @@ namespace Tensorflow.Keras
         TensorShapeConfig BuildInputShape { get; }
         TF_DataType DType { get; }
         int count_params();
+        void adapt(Tensor data, int? batch_size = null, int? steps = null);
     }
 }
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
index 6b2c38c3..711c38af 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
@@ -156,6 +156,7 @@ namespace Tensorflow.Keras.Layers
             IInitializer beta_initializer = null,
             IInitializer gamma_initializer = null);
 
+        public ILayer Normalization(int? axis = -1, float? mean = null, float? variance = null, bool invert = false);
         public ILayer LeakyReLU(float alpha = 0.3f);
 
         public ILayer LSTM(int units,
diff --git a/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs b/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs
index 9c9ae7d3..80f056fe 100644
--- a/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs
+++ b/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs
@@ -9,6 +9,9 @@ namespace Tensorflow.NumPy
     {
         public static long GetSize(Shape shape)
         {
+            if (shape.IsNull)
+                return 0;
+
             // scalar
             if (shape.ndim == 0)
                 return 1;
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
index 4e9369a8..87b595b6 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
@@ -159,5 +159,10 @@ namespace Tensorflow
         }
 
         public Trackable GetTrackable() { throw new NotImplementedException(); }
+
+        public void adapt(Tensor data, int? batch_size = null, int? steps = null)
+        {
+            throw new NotImplementedException();
+        }
     }
 }
diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs
index e02723b7..67530ddb 100644
--- a/src/TensorFlowNET.Core/tensorflow.cs
+++ b/src/TensorFlowNET.Core/tensorflow.cs
@@ -16,6 +16,7 @@
 using Serilog;
 using Serilog.Core;
+using System.Reflection;
 using System.Threading;
 using Tensorflow.Contexts;
 using Tensorflow.Eager;
@@ -52,7 +53,29 @@ namespace Tensorflow
         ThreadLocal<IEagerRunner> _runner = new ThreadLocal<IEagerRunner>(() => new EagerRunner());
         public IEagerRunner Runner => _runner.Value;
 
-        public IKerasApi keras { get; set; }
+        private IKerasApi _keras;
+        public IKerasApi keras
+        {
+            get
+            {
+                if (_keras != null)
+                {
+                    return _keras;
+                }
+
+                var k = Assembly.Load("Tensorflow.Keras");
+                var cls = k.GetTypes().FirstOrDefault(x => x.GetInterfaces().Contains(typeof(IKerasApi)));
+                if (cls != null)
+                {
+                    _keras = Activator.CreateInstance(cls) as IKerasApi;
+                    return _keras;
+                }
+                else
+                {
+                    throw new Exception("Can't find keras library.");
+                }
+            }
+        }
 
         public tensorflow()
         {
diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs
index 3934950b..0f809cba 100644
--- a/src/TensorFlowNET.Keras/Engine/Layer.cs
+++ b/src/TensorFlowNET.Keras/Engine/Layer.cs
@@ -344,5 +344,10 @@ namespace Tensorflow.Keras.Engine
 
         public virtual IKerasConfig get_config()
             => args;
+
+        public virtual void adapt(Tensor data, int? batch_size = null, int? steps = null)
+        {
+
+        }
     }
 }
diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs
index 9f1746d8..f7980706 100644
--- a/src/TensorFlowNET.Keras/KerasInterface.cs
+++ b/src/TensorFlowNET.Keras/KerasInterface.cs
@@ -20,10 +20,6 @@ namespace Tensorflow.Keras
     {
         private static KerasInterface _instance = null;
         private static readonly object _lock = new object();
-        private KerasInterface()
-        {
-            Tensorflow.Binding.tf.keras = this;
-        }
 
         public static KerasInterface Instance
         {
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index 22fd661d..67a58a59 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -872,5 +872,14 @@ namespace Tensorflow.Keras.Layers
                 Sparse = sparse,
                 CountWeights = count_weights
             });
+
+        public ILayer Normalization(int? axis = -1, float? mean = null, float? variance = null, bool invert = false)
+            => new Normalization(new NormalizationArgs
+            {
+                Axis = axis,
+                Mean = mean,
+                Variance = variance,
+                Invert = invert
+            });
     }
 }
diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs
new file mode 100644
index 00000000..c23dde69
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs
@@ -0,0 +1,173 @@
+/*****************************************************************************
+   Copyright 2023 Haiping Chen. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+******************************************************************************/
+
+using Tensorflow.Keras.ArgsDefinition;
+
+namespace Tensorflow.Keras.Layers
+{
+    public class Normalization : PreprocessingLayer
+    {
+        NormalizationArgs _args;
+
+        int[] axis;
+        int[] _reduce_axis;
+        IVariableV1 adapt_mean, adapt_variance, count;
+        Tensor mean, variance;
+        Shape _broadcast_shape;
+        float? input_mean, input_variance;
+        TF_DataType compute_dtype = tf.float32;
+
+        public Normalization(NormalizationArgs args) : base(args)
+        {
+            _args = args;
+            if (args.Axis == null)
+            {
+                axis = new int[0];
+            }
+            else
+            {
+                axis = args.Axis.axis;
+            }
+            input_mean = args.Mean;
+            input_variance = args.Variance;
+        }
+
+        public override void build(Shape input_shape)
+        {
+            base.build(input_shape);
+            var ndim = input_shape.ndim;
+            foreach (var (idx, x) in enumerate(axis))
+                if (x < 0)
+                    axis[idx] = ndim + x;
+
+            var _keep_axis = axis.Select(d => d >= 0 ? d : d + ndim).ToArray();
+            _reduce_axis = range(ndim).Where(d => !_keep_axis.Contains(d)).ToArray();
+            var _reduce_axis_mask = range(ndim).Select(d => _keep_axis.Contains(d) ? 0 : 1).ToArray();
+            // Broadcast any reduced axes.
+            _broadcast_shape = new Shape(range(ndim).Select(d => _keep_axis.Contains(d) ? input_shape.dims[d] : 1).ToArray());
+            var mean_and_var_shape = _keep_axis.Select(d => input_shape.dims[d]).ToArray();
+
+            var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType;
+            var param_shape = input_shape;
+
+            if (input_mean == null)
+            {
+                adapt_mean = add_weight("mean",
+                    mean_and_var_shape,
+                    dtype: tf.float32,
+                    initializer: tf.zeros_initializer,
+                    trainable: false);
+
+                adapt_variance = add_weight("variance",
+                    mean_and_var_shape,
+                    dtype: tf.float32,
+                    initializer: tf.ones_initializer,
+                    trainable: false);
+
+                count = add_weight("count",
+                    Shape.Scalar,
+                    dtype: tf.int64,
+                    initializer: tf.zeros_initializer,
+                    trainable: false);
+
+                finalize_state();
+            }
+            else
+            {
+                mean = input_mean * np.ones(mean_and_var_shape);
+                variance = input_variance * np.ones(mean_and_var_shape);
+                mean = tf.reshape(mean, _broadcast_shape);
+                variance = tf.reshape(variance, _broadcast_shape);
+                mean = tf.cast(mean, compute_dtype);
+                variance = tf.cast(variance, compute_dtype);
+            }
+        }
+
+        public override void reset_state()
+        {
+            if (input_mean != null || !built)
+            {
+                return;
+            }
+            adapt_mean.assign(tf.zeros_like(adapt_mean.AsTensor()));
+            adapt_variance.assign(tf.ones_like(adapt_variance.AsTensor()));
+            count.assign(tf.zeros_like(count.AsTensor()));
+        }
+
+        public override void finalize_state()
+        {
+            if (input_mean != null || !built)
+            {
+                return;
+            }
+            mean = tf.reshape(adapt_mean.AsTensor(), _broadcast_shape);
+            variance = tf.reshape(adapt_variance.AsTensor(), _broadcast_shape);
+        }
+
+        public override void update_state(Tensor data)
+        {
+            data = tf.cast(data, adapt_mean.dtype);
+            var (batch_mean, batch_variance) = tf.nn.moments(data, axes: _reduce_axis);
+            var batch_shape = tf.shape(data, out_type: count.dtype);
+
+            var batch_count = constant_op.constant(1L);
+            if (_reduce_axis != null)
+            {
+                var batch_reduce_shape = tf.gather(batch_shape, constant_op.constant(_reduce_axis));
+                batch_count = tf.reduce_prod(batch_reduce_shape);
+            }
+            var total_count = batch_count + count.AsTensor();
+            var batch_weight = tf.cast(batch_count, dtype: compute_dtype) / tf.cast(
+                total_count, dtype: compute_dtype);
+            var existing_weight = 1.0 - batch_weight;
+            var total_mean = adapt_mean.AsTensor() * existing_weight + batch_mean * batch_weight;
+
+            var total_variance = (
+                adapt_variance.AsTensor() + tf.square(adapt_mean.AsTensor() - total_mean)
+            ) * existing_weight + (
+                batch_variance + tf.square(batch_mean - total_mean)
+            ) * batch_weight;
+            adapt_mean.assign(total_mean);
+            adapt_variance.assign(total_variance);
+            count.assign(total_count);
+        }
+
+        public override Shape ComputeOutputShape(Shape input_shape)
+        {
+            return input_shape;
+        }
+
+        public override void adapt(Tensor data, int? batch_size = null, int? steps = null)
+        {
+            base.adapt(data, batch_size: batch_size, steps: steps);
+        }
+
+        protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+        {
+            if (_args.Invert)
+            {
+                return mean + (
+                    inputs * tf.maximum(tf.sqrt(variance), keras.backend.epsilon())
+                );
+            }
+            else
+            {
+                return (inputs - mean) / tf.maximum(
+                    tf.sqrt(variance), keras.backend.epsilon());
+            }
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs
index bd86874b..463936a3 100644
--- a/src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs
+++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs
@@ -3,14 +3,95 @@
 using System.Collections.Generic;
 using System.Text;
 using Tensorflow.Keras.ArgsDefinition;
 using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Engine.DataAdapters;
 
 namespace Tensorflow.Keras.Layers
 {
     public class PreprocessingLayer : Layer
     {
+        bool _is_compiled;
+        bool _is_adapted;
+        IVariableV1 _steps_per_execution;
+        PreprocessingLayerArgs _args;
         public PreprocessingLayer(PreprocessingLayerArgs args) : base(args)
         {
+            _args = args;
+        }
+
+        public override void adapt(Tensor data, int? batch_size = null, int? steps = null)
+        {
+            if (!_is_compiled)
+            {
+                compile();
+            }
+
+            if (built)
+            {
+                reset_state();
+            }
+
+            var data_handler = new DataHandler(new DataHandlerArgs
+            {
+                X = new Tensors(data),
+                BatchSize = _args.BatchSize,
+                Epochs = 1,
+                StepsPerExecution = _steps_per_execution
+            });
+
+            foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
+            {
+                foreach (var _ in data_handler.steps())
+                {
+                    run_step(iterator);
+                }
+            }
+            finalize_state();
+            _is_adapted = true;
+        }
+
+        private void run_step(OwnedIterator iterator)
+        {
+            var data = iterator.next();
+            _adapt_maybe_build(data[0]);
+            update_state(data[0]);
+        }
+
+        public virtual void reset_state()
+        {
+
+        }
+
+        public virtual void finalize_state()
+        {
+
+        }
+
+        public virtual void update_state(Tensor data)
+        {
+
+        }
+
+        private void _adapt_maybe_build(Tensor data)
+        {
+            if (!built)
+            {
+                var data_shape = data.shape;
+                var data_shape_nones = Enumerable.Range(0, data.ndim).Select(x => -1).ToArray();
+                _args.BatchInputShape = BatchInputShape ?? new Shape(data_shape_nones);
+                build(data_shape);
+                built = true;
+            }
+        }
+
+        public void compile(bool run_eagerly = false, int steps_per_execution = 1)
+        {
+            _steps_per_execution = tf.Variable(
+                steps_per_execution,
+                dtype: tf.int64,
+                aggregation: VariableAggregation.OnlyFirstReplica
+            );
+            _is_compiled = true;
         }
     }
 }
diff --git a/test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs b/test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs
index 576c641d..ab1db6b0 100644
--- a/test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs
@@ -1,5 +1,6 @@
 using Microsoft.VisualStudio.TestTools.UnitTesting;
 using System;
+using Tensorflow;
 using Tensorflow.Keras;
 using static Tensorflow.Binding;
 
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
index 78397c8e..03fd4929 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
@@ -177,6 +177,55 @@ namespace TensorFlowNET.Keras.UnitTest
             Assert.IsTrue(output[0].numpy().Equals(new[] { -0.99998f, 0.99998f }));
         }
 
+        /// <summary>
+        /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Normalization
+        /// </summary>
+        [TestMethod]
+        public void Normalization()
+        {
+            // Calculate a global mean and variance by analyzing the dataset in adapt().
+            var adapt_data = np.array(new[] { 1f, 2f, 3f, 4f, 5f });
+            var input_data = np.array(new[] { 1f, 2f, 3f });
+            var layer = tf.keras.layers.Normalization(axis: null);
+            layer.adapt(adapt_data);
+            var x = layer.Apply(input_data);
+            Assert.AreEqual(x.numpy(), new[] { -1.4142135f, -0.70710677f, 0f });
+
+            // Calculate a mean and variance for each index on the last axis.
+            adapt_data = np.array(new[,]
+            {
+                { 0, 7, 4 },
+                { 2, 9, 6 },
+                { 0, 7, 4 },
+                { 2, 9, 6 }
+            }, dtype: tf.float32);
+            input_data = np.array(new[,] { { 0, 7, 4 } }, dtype: tf.float32);
+            layer = tf.keras.layers.Normalization(axis: -1);
+            layer.adapt(adapt_data);
+            x = layer.Apply(input_data);
+            Equal(x.numpy().ToArray<float>(), new[] { -1f, -1f, -1f });
+
+            // Pass the mean and variance directly.
+            input_data = np.array(new[,] { { 1f }, { 2f }, { 3f } }, dtype: tf.float32);
+            layer = tf.keras.layers.Normalization(mean: 3f, variance: 2f);
+            x = layer.Apply(input_data);
+            Equal(x.numpy().ToArray<float>(), new[] { -1.4142135f, -0.70710677f, 0f });
+
+            // Use the layer to de-normalize inputs (after adapting the layer).
+            adapt_data = np.array(new[,]
+            {
+                { 0, 7, 4 },
+                { 2, 9, 6 },
+                { 0, 7, 4 },
+                { 2, 9, 6 }
+            }, dtype: tf.float32);
+            input_data = np.array(new[,] { { 1, 2, 3 } }, dtype: tf.float32);
+            layer = tf.keras.layers.Normalization(axis: -1, invert: true);
+            layer.adapt(adapt_data);
+            x = layer.Apply(input_data);
+            Equal(x.numpy().ToArray<float>(), new[] { 2f, 10f, 8f });
+        }
+
         /// <summary>
         /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/CategoryEncoding
         /// </summary>
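--
Usage sketch (reviewer note, not part of the patch): a minimal end-to-end example of
the new layer, mirroring the first unit test above. Only tf.keras.layers.Normalization,
adapt, and Apply come from this patch; everything else is existing TensorFlow.NET API.

    using static Tensorflow.Binding;
    using Tensorflow.NumPy;

    // Learn mean=3 and variance=2 from the data in a single adapt() pass,
    // then standardize new inputs as (value - mean) / sqrt(variance).
    var adapt_data = np.array(new[] { 1f, 2f, 3f, 4f, 5f });
    var layer = tf.keras.layers.Normalization(axis: null);
    layer.adapt(adapt_data);
    var x = layer.Apply(np.array(new[] { 1f, 2f, 3f }));
    // x is approximately [-1.4142135, -0.70710677, 0]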