@@ -0,0 +1,15 @@ | |||||
using Newtonsoft.Json; | |||||
namespace Tensorflow.Keras.ArgsDefinition; | |||||
public class NormalizationArgs : PreprocessingLayerArgs | |||||
{ | |||||
[JsonProperty("axis")] | |||||
public Axis? Axis { get; set; } | |||||
[JsonProperty("mean")] | |||||
public float? Mean { get; set; } | |||||
[JsonProperty("variance")] | |||||
public float? Variance { get; set; } | |||||
public bool Invert { get; set; } = false; | |||||
} |
@@ -23,5 +23,6 @@ namespace Tensorflow.Keras | |||||
TensorShapeConfig BuildInputShape { get; } | TensorShapeConfig BuildInputShape { get; } | ||||
TF_DataType DType { get; } | TF_DataType DType { get; } | ||||
int count_params(); | int count_params(); | ||||
void adapt(Tensor data, int? batch_size = null, int? steps = null); | |||||
} | } | ||||
} | } |
@@ -156,6 +156,7 @@ namespace Tensorflow.Keras.Layers | |||||
IInitializer beta_initializer = null, | IInitializer beta_initializer = null, | ||||
IInitializer gamma_initializer = null); | IInitializer gamma_initializer = null); | ||||
public ILayer Normalization(int? axis = -1, float? mean = null, float? variance = null, bool invert = false); | |||||
public ILayer LeakyReLU(float alpha = 0.3f); | public ILayer LeakyReLU(float alpha = 0.3f); | ||||
public ILayer LSTM(int units, | public ILayer LSTM(int units, | ||||
@@ -9,6 +9,9 @@ namespace Tensorflow.NumPy | |||||
{ | { | ||||
public static long GetSize(Shape shape) | public static long GetSize(Shape shape) | ||||
{ | { | ||||
if (shape.IsNull) | |||||
return 0; | |||||
// scalar | // scalar | ||||
if (shape.ndim == 0) | if (shape.ndim == 0) | ||||
return 1; | return 1; | ||||
@@ -159,5 +159,10 @@ namespace Tensorflow | |||||
} | } | ||||
public Trackable GetTrackable() { throw new NotImplementedException(); } | public Trackable GetTrackable() { throw new NotImplementedException(); } | ||||
public void adapt(Tensor data, int? batch_size = null, int? steps = null) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
} | } | ||||
} | } |
@@ -16,6 +16,7 @@ | |||||
using Serilog; | using Serilog; | ||||
using Serilog.Core; | using Serilog.Core; | ||||
using System.Reflection; | |||||
using System.Threading; | using System.Threading; | ||||
using Tensorflow.Contexts; | using Tensorflow.Contexts; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
@@ -52,7 +53,29 @@ namespace Tensorflow | |||||
ThreadLocal<IEagerRunner> _runner = new ThreadLocal<IEagerRunner>(() => new EagerRunner()); | ThreadLocal<IEagerRunner> _runner = new ThreadLocal<IEagerRunner>(() => new EagerRunner()); | ||||
public IEagerRunner Runner => _runner.Value; | public IEagerRunner Runner => _runner.Value; | ||||
public IKerasApi keras { get; set; } | |||||
private IKerasApi _keras; | |||||
public IKerasApi keras | |||||
{ | |||||
get | |||||
{ | |||||
if (_keras != null) | |||||
{ | |||||
return _keras; | |||||
} | |||||
var k = Assembly.Load("Tensorflow.Keras"); | |||||
var cls = k.GetTypes().FirstOrDefault(x => x.GetInterfaces().Contains(typeof(IKerasApi))); | |||||
if (cls != null) | |||||
{ | |||||
_keras = Activator.CreateInstance(cls) as IKerasApi; | |||||
return _keras; | |||||
} | |||||
else | |||||
{ | |||||
throw new Exception("Can't find keras library."); | |||||
} | |||||
} | |||||
} | |||||
public tensorflow() | public tensorflow() | ||||
{ | { | ||||
@@ -344,5 +344,10 @@ namespace Tensorflow.Keras.Engine | |||||
public virtual IKerasConfig get_config() | public virtual IKerasConfig get_config() | ||||
=> args; | => args; | ||||
public virtual void adapt(Tensor data, int? batch_size = null, int? steps = null) | |||||
{ | |||||
} | |||||
} | } | ||||
} | } |
@@ -20,10 +20,6 @@ namespace Tensorflow.Keras | |||||
{ | { | ||||
private static KerasInterface _instance = null; | private static KerasInterface _instance = null; | ||||
private static readonly object _lock = new object(); | private static readonly object _lock = new object(); | ||||
private KerasInterface() | |||||
{ | |||||
Tensorflow.Binding.tf.keras = this; | |||||
} | |||||
public static KerasInterface Instance | public static KerasInterface Instance | ||||
{ | { | ||||
@@ -872,5 +872,14 @@ namespace Tensorflow.Keras.Layers | |||||
Sparse = sparse, | Sparse = sparse, | ||||
CountWeights = count_weights | CountWeights = count_weights | ||||
}); | }); | ||||
public ILayer Normalization(int? axis = -1, float? mean = null, float? variance = null, bool invert = false) | |||||
=> new Normalization(new NormalizationArgs | |||||
{ | |||||
Axis = axis, | |||||
Mean = mean, | |||||
Variance = variance, | |||||
Invert = invert | |||||
}); | |||||
} | } | ||||
} | } |
@@ -0,0 +1,173 @@ | |||||
/***************************************************************************** | |||||
Copyright 2023 Haiping Chen. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | |||||
public class Normalization : PreprocessingLayer | |||||
{ | |||||
NormalizationArgs _args; | |||||
int[] axis; | |||||
int[] _reduce_axis; | |||||
IVariableV1 adapt_mean, adapt_variance, count; | |||||
Tensor mean, variance; | |||||
Shape _broadcast_shape; | |||||
float? input_mean, input_variance; | |||||
TF_DataType compute_dtype = tf.float32; | |||||
public Normalization(NormalizationArgs args) : base(args) | |||||
{ | |||||
_args = args; | |||||
if (args.Axis == null) | |||||
{ | |||||
axis = new int[0]; | |||||
} | |||||
else | |||||
{ | |||||
axis = args.Axis.axis; | |||||
} | |||||
input_mean = args.Mean; | |||||
input_variance = args.Variance; | |||||
} | |||||
public override void build(Shape input_shape) | |||||
{ | |||||
base.build(input_shape); | |||||
var ndim = input_shape.ndim; | |||||
foreach (var (idx, x) in enumerate(axis)) | |||||
if (x < 0) | |||||
axis[idx] = ndim + x; | |||||
var _keep_axis = axis.Select(d => d >= 0 ? d : d + ndim).ToArray(); | |||||
_reduce_axis = range(ndim).Where(d => !_keep_axis.Contains(d)).ToArray(); | |||||
var _reduce_axis_mask = range(ndim).Select(d => _keep_axis.Contains(d) ? 0 : 1).ToArray(); | |||||
// Broadcast any reduced axes. | |||||
_broadcast_shape = new Shape(range(ndim).Select(d => _keep_axis.Contains(d) ? input_shape.dims[d] : 1).ToArray()); | |||||
var mean_and_var_shape = _keep_axis.Select(d => input_shape.dims[d]).ToArray(); | |||||
var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; | |||||
var param_shape = input_shape; | |||||
if(input_mean == null) | |||||
{ | |||||
adapt_mean = add_weight("mean", | |||||
mean_and_var_shape, | |||||
dtype: tf.float32, | |||||
initializer: tf.zeros_initializer, | |||||
trainable: false); | |||||
adapt_variance = add_weight("variance", | |||||
mean_and_var_shape, | |||||
dtype: tf.float32, | |||||
initializer: tf.ones_initializer, | |||||
trainable: false); | |||||
count = add_weight("count", | |||||
Shape.Scalar, | |||||
dtype: tf.int64, | |||||
initializer: tf.zeros_initializer, | |||||
trainable: false); | |||||
finalize_state(); | |||||
} | |||||
else | |||||
{ | |||||
mean = input_mean * np.ones(mean_and_var_shape); | |||||
variance = input_variance * np.ones(mean_and_var_shape); | |||||
mean = tf.reshape(mean, _broadcast_shape); | |||||
variance = tf.reshape(variance, _broadcast_shape); | |||||
mean = tf.cast(mean, compute_dtype); | |||||
variance = tf.cast(variance, compute_dtype); | |||||
} | |||||
} | |||||
public override void reset_state() | |||||
{ | |||||
if (input_mean != null && !built) | |||||
{ | |||||
return; | |||||
} | |||||
adapt_mean.assign(tf.zeros_like(adapt_mean.AsTensor())); | |||||
adapt_variance.assign(tf.ones_like(adapt_variance.AsTensor())); | |||||
count.assign(tf.zeros_like(count.AsTensor())); | |||||
} | |||||
public override void finalize_state() | |||||
{ | |||||
if (input_mean != null && !built) | |||||
{ | |||||
return; | |||||
} | |||||
mean = tf.reshape(adapt_mean.AsTensor(), _broadcast_shape); | |||||
variance = tf.reshape(adapt_variance.AsTensor(), _broadcast_shape); | |||||
} | |||||
public override void update_state(Tensor data) | |||||
{ | |||||
data = tf.cast(data, adapt_mean.dtype); | |||||
var (batch_mean, batch_variance) = tf.nn.moments(data, axes: _reduce_axis); | |||||
var batch_shape = tf.shape(data, out_type: count.dtype); | |||||
var batch_count = constant_op.constant(1L); | |||||
if (_reduce_axis != null) | |||||
{ | |||||
var batch_reduce_shape = tf.gather(batch_shape, constant_op.constant(_reduce_axis)); | |||||
batch_count = tf.reduce_prod(batch_reduce_shape); | |||||
} | |||||
var total_count = batch_count + count.AsTensor(); | |||||
var batch_weight = tf.cast(batch_count, dtype: compute_dtype) / tf.cast( | |||||
total_count, dtype: compute_dtype); | |||||
var existing_weight = 1.0 - batch_weight; | |||||
var total_mean = adapt_mean.AsTensor() * existing_weight + batch_mean * batch_weight; | |||||
var total_variance = ( | |||||
adapt_variance.AsTensor() + tf.square(adapt_mean.AsTensor() - total_mean) | |||||
) * existing_weight + ( | |||||
batch_variance + tf.square(batch_mean - total_mean) | |||||
) * batch_weight; | |||||
adapt_mean.assign(total_mean); | |||||
adapt_variance.assign(total_variance); | |||||
count.assign(total_count); | |||||
} | |||||
public override Shape ComputeOutputShape(Shape input_shape) | |||||
{ | |||||
return input_shape; | |||||
} | |||||
public override void adapt(Tensor data, int? batch_size = null, int? steps = null) | |||||
{ | |||||
base.adapt(data, batch_size: batch_size, steps: steps); | |||||
} | |||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
{ | |||||
if (_args.Invert) | |||||
{ | |||||
return mean + ( | |||||
inputs * tf.maximum(tf.sqrt(variance), keras.backend.epsilon()) | |||||
); | |||||
} | |||||
else | |||||
{ | |||||
return (inputs - mean) / tf.maximum( | |||||
tf.sqrt(variance), keras.backend.epsilon()); | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -3,14 +3,95 @@ using System.Collections.Generic; | |||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Engine.DataAdapters; | |||||
namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
{ | { | ||||
public class PreprocessingLayer : Layer | public class PreprocessingLayer : Layer | ||||
{ | { | ||||
bool _is_compiled; | |||||
bool _is_adapted; | |||||
IVariableV1 _steps_per_execution; | |||||
PreprocessingLayerArgs _args; | |||||
public PreprocessingLayer(PreprocessingLayerArgs args) : base(args) | public PreprocessingLayer(PreprocessingLayerArgs args) : base(args) | ||||
{ | { | ||||
_args = args; | |||||
} | |||||
public override void adapt(Tensor data, int? batch_size = null, int? steps = null) | |||||
{ | |||||
if (!_is_compiled) | |||||
{ | |||||
compile(); | |||||
} | |||||
if (built) | |||||
{ | |||||
reset_state(); | |||||
} | |||||
var data_handler = new DataHandler(new DataHandlerArgs | |||||
{ | |||||
X = new Tensors(data), | |||||
BatchSize = _args.BatchSize, | |||||
Epochs = 1, | |||||
StepsPerExecution = _steps_per_execution | |||||
}); | |||||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||||
{ | |||||
foreach (var _ in data_handler.steps()) | |||||
{ | |||||
run_step(iterator); | |||||
} | |||||
} | |||||
finalize_state(); | |||||
_is_adapted = true; | |||||
} | |||||
private void run_step(OwnedIterator iterator) | |||||
{ | |||||
var data = iterator.next(); | |||||
_adapt_maybe_build(data[0]); | |||||
update_state(data[0]); | |||||
} | |||||
public virtual void reset_state() | |||||
{ | |||||
} | |||||
public virtual void finalize_state() | |||||
{ | |||||
} | |||||
public virtual void update_state(Tensor data) | |||||
{ | |||||
} | |||||
private void _adapt_maybe_build(Tensor data) | |||||
{ | |||||
if (!built) | |||||
{ | |||||
var data_shape = data.shape; | |||||
var data_shape_nones = Enumerable.Range(0, data.ndim).Select(x => -1).ToArray(); | |||||
_args.BatchInputShape = BatchInputShape ?? new Shape(data_shape_nones); | |||||
build(data_shape); | |||||
built = true; | |||||
} | |||||
} | |||||
public void compile(bool run_eagerly = false, int steps_per_execution = 1) | |||||
{ | |||||
_steps_per_execution = tf.Variable( | |||||
steps_per_execution, | |||||
dtype: tf.int64, | |||||
aggregation: VariableAggregation.OnlyFirstReplica | |||||
); | |||||
_is_compiled = true; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -1,5 +1,6 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | using System; | ||||
using Tensorflow; | |||||
using Tensorflow.Keras; | using Tensorflow.Keras; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -177,6 +177,55 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
Assert.IsTrue(output[0].numpy().Equals(new[] { -0.99998f, 0.99998f })); | Assert.IsTrue(output[0].numpy().Equals(new[] { -0.99998f, 0.99998f })); | ||||
} | } | ||||
/// <summary> | |||||
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Normalization | |||||
/// </summary> | |||||
[TestMethod] | |||||
public void Normalization() | |||||
{ | |||||
// Calculate a global mean and variance by analyzing the dataset in adapt(). | |||||
var adapt_data = np.array(new[] { 1f, 2f, 3f, 4f, 5f }); | |||||
var input_data = np.array(new[] { 1f, 2f, 3f }); | |||||
var layer = tf.keras.layers.Normalization(axis: null); | |||||
layer.adapt(adapt_data); | |||||
var x = layer.Apply(input_data); | |||||
Assert.AreEqual(x.numpy(), new[] { -1.4142135f, -0.70710677f, 0f }); | |||||
// Calculate a mean and variance for each index on the last axis. | |||||
adapt_data = np.array(new[,] | |||||
{ | |||||
{ 0, 7, 4 }, | |||||
{ 2, 9, 6 }, | |||||
{ 0, 7, 4 }, | |||||
{ 2, 9, 6 } | |||||
}, dtype: tf.float32); | |||||
input_data = np.array(new[,] { { 0, 7, 4 } }, dtype: tf.float32); | |||||
layer = tf.keras.layers.Normalization(axis: -1); | |||||
layer.adapt(adapt_data); | |||||
x = layer.Apply(input_data); | |||||
Equal(x.numpy().ToArray<float>(), new[] { -1f, -1f, -1f }); | |||||
// Pass the mean and variance directly. | |||||
input_data = np.array(new[,] { { 1f }, { 2f }, { 3f } }, dtype: tf.float32); | |||||
layer = tf.keras.layers.Normalization(mean: 3f, variance: 2f); | |||||
x = layer.Apply(input_data); | |||||
Equal(x.numpy().ToArray<float>(), new[] { -1.4142135f, -0.70710677f, 0f }); | |||||
// Use the layer to de-normalize inputs (after adapting the layer). | |||||
adapt_data = np.array(new[,] | |||||
{ | |||||
{ 0, 7, 4 }, | |||||
{ 2, 9, 6 }, | |||||
{ 0, 7, 4 }, | |||||
{ 2, 9, 6 } | |||||
}, dtype: tf.float32); | |||||
input_data = np.array(new[,] { { 1, 2, 3 } }, dtype: tf.float32); | |||||
layer = tf.keras.layers.Normalization(axis: -1, invert: true); | |||||
layer.adapt(adapt_data); | |||||
x = layer.Apply(input_data); | |||||
Equal(x.numpy().ToArray<float>(), new[] { -2f, -10f, -8f }); | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/CategoryEncoding | /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/CategoryEncoding | ||||
/// </summary> | /// </summary> | ||||