@@ -0,0 +1,15 @@ | |||
using Newtonsoft.Json; | |||
namespace Tensorflow.Keras.ArgsDefinition; | |||
public class NormalizationArgs : PreprocessingLayerArgs | |||
{ | |||
[JsonProperty("axis")] | |||
public Axis? Axis { get; set; } | |||
[JsonProperty("mean")] | |||
public float? Mean { get; set; } | |||
[JsonProperty("variance")] | |||
public float? Variance { get; set; } | |||
public bool Invert { get; set; } = false; | |||
} |
@@ -23,5 +23,6 @@ namespace Tensorflow.Keras | |||
TensorShapeConfig BuildInputShape { get; } | |||
TF_DataType DType { get; } | |||
int count_params(); | |||
void adapt(Tensor data, int? batch_size = null, int? steps = null); | |||
} | |||
} |
@@ -156,6 +156,7 @@ namespace Tensorflow.Keras.Layers | |||
IInitializer beta_initializer = null, | |||
IInitializer gamma_initializer = null); | |||
public ILayer Normalization(int? axis = -1, float? mean = null, float? variance = null, bool invert = false); | |||
public ILayer LeakyReLU(float alpha = 0.3f); | |||
public ILayer LSTM(int units, | |||
@@ -9,6 +9,9 @@ namespace Tensorflow.NumPy | |||
{ | |||
public static long GetSize(Shape shape) | |||
{ | |||
if (shape.IsNull) | |||
return 0; | |||
// scalar | |||
if (shape.ndim == 0) | |||
return 1; | |||
@@ -159,5 +159,10 @@ namespace Tensorflow | |||
} | |||
public Trackable GetTrackable() { throw new NotImplementedException(); } | |||
public void adapt(Tensor data, int? batch_size = null, int? steps = null) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -16,6 +16,7 @@ | |||
using Serilog; | |||
using Serilog.Core; | |||
using System.Reflection; | |||
using System.Threading; | |||
using Tensorflow.Contexts; | |||
using Tensorflow.Eager; | |||
@@ -52,7 +53,29 @@ namespace Tensorflow | |||
ThreadLocal<IEagerRunner> _runner = new ThreadLocal<IEagerRunner>(() => new EagerRunner()); | |||
public IEagerRunner Runner => _runner.Value; | |||
public IKerasApi keras { get; set; } | |||
private IKerasApi _keras; | |||
public IKerasApi keras | |||
{ | |||
get | |||
{ | |||
if (_keras != null) | |||
{ | |||
return _keras; | |||
} | |||
var k = Assembly.Load("Tensorflow.Keras"); | |||
var cls = k.GetTypes().FirstOrDefault(x => x.GetInterfaces().Contains(typeof(IKerasApi))); | |||
if (cls != null) | |||
{ | |||
_keras = Activator.CreateInstance(cls) as IKerasApi; | |||
return _keras; | |||
} | |||
else | |||
{ | |||
throw new Exception("Can't find keras library."); | |||
} | |||
} | |||
} | |||
public tensorflow() | |||
{ | |||
@@ -344,5 +344,10 @@ namespace Tensorflow.Keras.Engine | |||
public virtual IKerasConfig get_config() | |||
=> args; | |||
public virtual void adapt(Tensor data, int? batch_size = null, int? steps = null) | |||
{ | |||
} | |||
} | |||
} |
@@ -20,10 +20,6 @@ namespace Tensorflow.Keras | |||
{ | |||
private static KerasInterface _instance = null; | |||
private static readonly object _lock = new object(); | |||
private KerasInterface() | |||
{ | |||
Tensorflow.Binding.tf.keras = this; | |||
} | |||
public static KerasInterface Instance | |||
{ | |||
@@ -872,5 +872,14 @@ namespace Tensorflow.Keras.Layers | |||
Sparse = sparse, | |||
CountWeights = count_weights | |||
}); | |||
public ILayer Normalization(int? axis = -1, float? mean = null, float? variance = null, bool invert = false) | |||
=> new Normalization(new NormalizationArgs | |||
{ | |||
Axis = axis, | |||
Mean = mean, | |||
Variance = variance, | |||
Invert = invert | |||
}); | |||
} | |||
} |
@@ -0,0 +1,173 @@ | |||
/***************************************************************************** | |||
Copyright 2023 Haiping Chen. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Tensorflow.Keras.ArgsDefinition; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public class Normalization : PreprocessingLayer | |||
{ | |||
NormalizationArgs _args; | |||
int[] axis; | |||
int[] _reduce_axis; | |||
IVariableV1 adapt_mean, adapt_variance, count; | |||
Tensor mean, variance; | |||
Shape _broadcast_shape; | |||
float? input_mean, input_variance; | |||
TF_DataType compute_dtype = tf.float32; | |||
public Normalization(NormalizationArgs args) : base(args) | |||
{ | |||
_args = args; | |||
if (args.Axis == null) | |||
{ | |||
axis = new int[0]; | |||
} | |||
else | |||
{ | |||
axis = args.Axis.axis; | |||
} | |||
input_mean = args.Mean; | |||
input_variance = args.Variance; | |||
} | |||
public override void build(Shape input_shape) | |||
{ | |||
base.build(input_shape); | |||
var ndim = input_shape.ndim; | |||
foreach (var (idx, x) in enumerate(axis)) | |||
if (x < 0) | |||
axis[idx] = ndim + x; | |||
var _keep_axis = axis.Select(d => d >= 0 ? d : d + ndim).ToArray(); | |||
_reduce_axis = range(ndim).Where(d => !_keep_axis.Contains(d)).ToArray(); | |||
var _reduce_axis_mask = range(ndim).Select(d => _keep_axis.Contains(d) ? 0 : 1).ToArray(); | |||
// Broadcast any reduced axes. | |||
_broadcast_shape = new Shape(range(ndim).Select(d => _keep_axis.Contains(d) ? input_shape.dims[d] : 1).ToArray()); | |||
var mean_and_var_shape = _keep_axis.Select(d => input_shape.dims[d]).ToArray(); | |||
var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; | |||
var param_shape = input_shape; | |||
if(input_mean == null) | |||
{ | |||
adapt_mean = add_weight("mean", | |||
mean_and_var_shape, | |||
dtype: tf.float32, | |||
initializer: tf.zeros_initializer, | |||
trainable: false); | |||
adapt_variance = add_weight("variance", | |||
mean_and_var_shape, | |||
dtype: tf.float32, | |||
initializer: tf.ones_initializer, | |||
trainable: false); | |||
count = add_weight("count", | |||
Shape.Scalar, | |||
dtype: tf.int64, | |||
initializer: tf.zeros_initializer, | |||
trainable: false); | |||
finalize_state(); | |||
} | |||
else | |||
{ | |||
mean = input_mean * np.ones(mean_and_var_shape); | |||
variance = input_variance * np.ones(mean_and_var_shape); | |||
mean = tf.reshape(mean, _broadcast_shape); | |||
variance = tf.reshape(variance, _broadcast_shape); | |||
mean = tf.cast(mean, compute_dtype); | |||
variance = tf.cast(variance, compute_dtype); | |||
} | |||
} | |||
public override void reset_state() | |||
{ | |||
if (input_mean != null && !built) | |||
{ | |||
return; | |||
} | |||
adapt_mean.assign(tf.zeros_like(adapt_mean.AsTensor())); | |||
adapt_variance.assign(tf.ones_like(adapt_variance.AsTensor())); | |||
count.assign(tf.zeros_like(count.AsTensor())); | |||
} | |||
public override void finalize_state() | |||
{ | |||
if (input_mean != null && !built) | |||
{ | |||
return; | |||
} | |||
mean = tf.reshape(adapt_mean.AsTensor(), _broadcast_shape); | |||
variance = tf.reshape(adapt_variance.AsTensor(), _broadcast_shape); | |||
} | |||
public override void update_state(Tensor data) | |||
{ | |||
data = tf.cast(data, adapt_mean.dtype); | |||
var (batch_mean, batch_variance) = tf.nn.moments(data, axes: _reduce_axis); | |||
var batch_shape = tf.shape(data, out_type: count.dtype); | |||
var batch_count = constant_op.constant(1L); | |||
if (_reduce_axis != null) | |||
{ | |||
var batch_reduce_shape = tf.gather(batch_shape, constant_op.constant(_reduce_axis)); | |||
batch_count = tf.reduce_prod(batch_reduce_shape); | |||
} | |||
var total_count = batch_count + count.AsTensor(); | |||
var batch_weight = tf.cast(batch_count, dtype: compute_dtype) / tf.cast( | |||
total_count, dtype: compute_dtype); | |||
var existing_weight = 1.0 - batch_weight; | |||
var total_mean = adapt_mean.AsTensor() * existing_weight + batch_mean * batch_weight; | |||
var total_variance = ( | |||
adapt_variance.AsTensor() + tf.square(adapt_mean.AsTensor() - total_mean) | |||
) * existing_weight + ( | |||
batch_variance + tf.square(batch_mean - total_mean) | |||
) * batch_weight; | |||
adapt_mean.assign(total_mean); | |||
adapt_variance.assign(total_variance); | |||
count.assign(total_count); | |||
} | |||
public override Shape ComputeOutputShape(Shape input_shape) | |||
{ | |||
return input_shape; | |||
} | |||
public override void adapt(Tensor data, int? batch_size = null, int? steps = null) | |||
{ | |||
base.adapt(data, batch_size: batch_size, steps: steps); | |||
} | |||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
{ | |||
if (_args.Invert) | |||
{ | |||
return mean + ( | |||
inputs * tf.maximum(tf.sqrt(variance), keras.backend.epsilon()) | |||
); | |||
} | |||
else | |||
{ | |||
return (inputs - mean) / tf.maximum( | |||
tf.sqrt(variance), keras.backend.epsilon()); | |||
} | |||
} | |||
} | |||
} |
@@ -3,14 +3,95 @@ using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Engine.DataAdapters; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public class PreprocessingLayer : Layer | |||
{ | |||
bool _is_compiled; | |||
bool _is_adapted; | |||
IVariableV1 _steps_per_execution; | |||
PreprocessingLayerArgs _args; | |||
public PreprocessingLayer(PreprocessingLayerArgs args) : base(args) | |||
{ | |||
_args = args; | |||
} | |||
public override void adapt(Tensor data, int? batch_size = null, int? steps = null) | |||
{ | |||
if (!_is_compiled) | |||
{ | |||
compile(); | |||
} | |||
if (built) | |||
{ | |||
reset_state(); | |||
} | |||
var data_handler = new DataHandler(new DataHandlerArgs | |||
{ | |||
X = new Tensors(data), | |||
BatchSize = _args.BatchSize, | |||
Epochs = 1, | |||
StepsPerExecution = _steps_per_execution | |||
}); | |||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||
{ | |||
foreach (var _ in data_handler.steps()) | |||
{ | |||
run_step(iterator); | |||
} | |||
} | |||
finalize_state(); | |||
_is_adapted = true; | |||
} | |||
private void run_step(OwnedIterator iterator) | |||
{ | |||
var data = iterator.next(); | |||
_adapt_maybe_build(data[0]); | |||
update_state(data[0]); | |||
} | |||
public virtual void reset_state() | |||
{ | |||
} | |||
public virtual void finalize_state() | |||
{ | |||
} | |||
public virtual void update_state(Tensor data) | |||
{ | |||
} | |||
private void _adapt_maybe_build(Tensor data) | |||
{ | |||
if (!built) | |||
{ | |||
var data_shape = data.shape; | |||
var data_shape_nones = Enumerable.Range(0, data.ndim).Select(x => -1).ToArray(); | |||
_args.BatchInputShape = BatchInputShape ?? new Shape(data_shape_nones); | |||
build(data_shape); | |||
built = true; | |||
} | |||
} | |||
public void compile(bool run_eagerly = false, int steps_per_execution = 1) | |||
{ | |||
_steps_per_execution = tf.Variable( | |||
steps_per_execution, | |||
dtype: tf.int64, | |||
aggregation: VariableAggregation.OnlyFirstReplica | |||
); | |||
_is_compiled = true; | |||
} | |||
} | |||
} |
@@ -1,5 +1,6 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using Tensorflow; | |||
using Tensorflow.Keras; | |||
using static Tensorflow.Binding; | |||
@@ -177,6 +177,55 @@ namespace TensorFlowNET.Keras.UnitTest | |||
Assert.IsTrue(output[0].numpy().Equals(new[] { -0.99998f, 0.99998f })); | |||
} | |||
/// <summary> | |||
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Normalization | |||
/// </summary> | |||
[TestMethod] | |||
public void Normalization() | |||
{ | |||
// Calculate a global mean and variance by analyzing the dataset in adapt(). | |||
var adapt_data = np.array(new[] { 1f, 2f, 3f, 4f, 5f }); | |||
var input_data = np.array(new[] { 1f, 2f, 3f }); | |||
var layer = tf.keras.layers.Normalization(axis: null); | |||
layer.adapt(adapt_data); | |||
var x = layer.Apply(input_data); | |||
Assert.AreEqual(x.numpy(), new[] { -1.4142135f, -0.70710677f, 0f }); | |||
// Calculate a mean and variance for each index on the last axis. | |||
adapt_data = np.array(new[,] | |||
{ | |||
{ 0, 7, 4 }, | |||
{ 2, 9, 6 }, | |||
{ 0, 7, 4 }, | |||
{ 2, 9, 6 } | |||
}, dtype: tf.float32); | |||
input_data = np.array(new[,] { { 0, 7, 4 } }, dtype: tf.float32); | |||
layer = tf.keras.layers.Normalization(axis: -1); | |||
layer.adapt(adapt_data); | |||
x = layer.Apply(input_data); | |||
Equal(x.numpy().ToArray<float>(), new[] { -1f, -1f, -1f }); | |||
// Pass the mean and variance directly. | |||
input_data = np.array(new[,] { { 1f }, { 2f }, { 3f } }, dtype: tf.float32); | |||
layer = tf.keras.layers.Normalization(mean: 3f, variance: 2f); | |||
x = layer.Apply(input_data); | |||
Equal(x.numpy().ToArray<float>(), new[] { -1.4142135f, -0.70710677f, 0f }); | |||
// Use the layer to de-normalize inputs (after adapting the layer). | |||
adapt_data = np.array(new[,] | |||
{ | |||
{ 0, 7, 4 }, | |||
{ 2, 9, 6 }, | |||
{ 0, 7, 4 }, | |||
{ 2, 9, 6 } | |||
}, dtype: tf.float32); | |||
input_data = np.array(new[,] { { 1, 2, 3 } }, dtype: tf.float32); | |||
layer = tf.keras.layers.Normalization(axis: -1, invert: true); | |||
layer.adapt(adapt_data); | |||
x = layer.Apply(input_data); | |||
Equal(x.numpy().ToArray<float>(), new[] { -2f, -10f, -8f }); | |||
} | |||
/// <summary> | |||
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/CategoryEncoding | |||
/// </summary> | |||