Browse Source

Merge branch 'SciSharp:master' into master

tags/v0.100.5-BERT-load
Long GitHub 2 years ago
parent
commit
046b598bc5
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 367 additions and 5 deletions
  1. +15
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/NormalizationArgs.cs
  2. +1
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  3. +1
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
  4. +3
    -0
      src/TensorFlowNET.Core/NumPy/ShapeHelper.cs
  5. +5
    -0
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  6. +24
    -1
      src/TensorFlowNET.Core/tensorflow.cs
  7. +5
    -0
      src/TensorFlowNET.Keras/Engine/Layer.cs
  8. +0
    -4
      src/TensorFlowNET.Keras/KerasInterface.cs
  9. +9
    -0
      src/TensorFlowNET.Keras/Layers/LayersApi.cs
  10. +173
    -0
      src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs
  11. +81
    -0
      src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs
  12. +1
    -0
      test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs
  13. +49
    -0
      test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs

+ 15
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/NormalizationArgs.cs View File

@@ -0,0 +1,15 @@
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition;

public class NormalizationArgs : PreprocessingLayerArgs
{
[JsonProperty("axis")]
public Axis? Axis { get; set; }
[JsonProperty("mean")]
public float? Mean { get; set; }
[JsonProperty("variance")]
public float? Variance { get; set; }
public bool Invert { get; set; } = false;
}

+ 1
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -23,5 +23,6 @@ namespace Tensorflow.Keras
TensorShapeConfig BuildInputShape { get; }
TF_DataType DType { get; }
int count_params();
void adapt(Tensor data, int? batch_size = null, int? steps = null);
}
}

+ 1
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs View File

@@ -156,6 +156,7 @@ namespace Tensorflow.Keras.Layers
IInitializer beta_initializer = null,
IInitializer gamma_initializer = null);

public ILayer Normalization(int? axis = -1, float? mean = null, float? variance = null, bool invert = false);
public ILayer LeakyReLU(float alpha = 0.3f);

public ILayer LSTM(int units,


+ 3
- 0
src/TensorFlowNET.Core/NumPy/ShapeHelper.cs View File

@@ -9,6 +9,9 @@ namespace Tensorflow.NumPy
{
public static long GetSize(Shape shape)
{
if (shape.IsNull)
return 0;

// scalar
if (shape.ndim == 0)
return 1;


+ 5
- 0
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -159,5 +159,10 @@ namespace Tensorflow
}

public Trackable GetTrackable() { throw new NotImplementedException(); }

public void adapt(Tensor data, int? batch_size = null, int? steps = null)
{
throw new NotImplementedException();
}
}
}

+ 24
- 1
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -16,6 +16,7 @@

using Serilog;
using Serilog.Core;
using System.Reflection;
using System.Threading;
using Tensorflow.Contexts;
using Tensorflow.Eager;
@@ -52,7 +53,29 @@ namespace Tensorflow
ThreadLocal<IEagerRunner> _runner = new ThreadLocal<IEagerRunner>(() => new EagerRunner());
public IEagerRunner Runner => _runner.Value;

public IKerasApi keras { get; set; }
private IKerasApi _keras;
public IKerasApi keras
{
get
{
if (_keras != null)
{
return _keras;
}

var k = Assembly.Load("Tensorflow.Keras");
var cls = k.GetTypes().FirstOrDefault(x => x.GetInterfaces().Contains(typeof(IKerasApi)));
if (cls != null)
{
_keras = Activator.CreateInstance(cls) as IKerasApi;
return _keras;
}
else
{
throw new Exception("Can't find keras library.");
}
}
}

public tensorflow()
{


+ 5
- 0
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -344,5 +344,10 @@ namespace Tensorflow.Keras.Engine

public virtual IKerasConfig get_config()
=> args;

public virtual void adapt(Tensor data, int? batch_size = null, int? steps = null)
{
}
}
}

+ 0
- 4
src/TensorFlowNET.Keras/KerasInterface.cs View File

@@ -20,10 +20,6 @@ namespace Tensorflow.Keras
{
private static KerasInterface _instance = null;
private static readonly object _lock = new object();
private KerasInterface()
{
Tensorflow.Binding.tf.keras = this;
}

public static KerasInterface Instance
{


+ 9
- 0
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -872,5 +872,14 @@ namespace Tensorflow.Keras.Layers
Sparse = sparse,
CountWeights = count_weights
});

public ILayer Normalization(int? axis = -1, float? mean = null, float? variance = null, bool invert = false)
=> new Normalization(new NormalizationArgs
{
Axis = axis,
Mean = mean,
Variance = variance,
Invert = invert
});
}
}

+ 173
- 0
src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs View File

@@ -0,0 +1,173 @@
/*****************************************************************************
Copyright 2023 Haiping Chen. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Layers
{
public class Normalization : PreprocessingLayer
{
NormalizationArgs _args;

int[] axis;
int[] _reduce_axis;
IVariableV1 adapt_mean, adapt_variance, count;
Tensor mean, variance;
Shape _broadcast_shape;
float? input_mean, input_variance;
TF_DataType compute_dtype = tf.float32;

public Normalization(NormalizationArgs args) : base(args)
{
_args = args;
if (args.Axis == null)
{
axis = new int[0];
}
else
{
axis = args.Axis.axis;
}
input_mean = args.Mean;
input_variance = args.Variance;
}

public override void build(Shape input_shape)
{
base.build(input_shape);
var ndim = input_shape.ndim;
foreach (var (idx, x) in enumerate(axis))
if (x < 0)
axis[idx] = ndim + x;

var _keep_axis = axis.Select(d => d >= 0 ? d : d + ndim).ToArray();
_reduce_axis = range(ndim).Where(d => !_keep_axis.Contains(d)).ToArray();
var _reduce_axis_mask = range(ndim).Select(d => _keep_axis.Contains(d) ? 0 : 1).ToArray();
// Broadcast any reduced axes.
_broadcast_shape = new Shape(range(ndim).Select(d => _keep_axis.Contains(d) ? input_shape.dims[d] : 1).ToArray());
var mean_and_var_shape = _keep_axis.Select(d => input_shape.dims[d]).ToArray();

var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType;
var param_shape = input_shape;

if(input_mean == null)
{
adapt_mean = add_weight("mean",
mean_and_var_shape,
dtype: tf.float32,
initializer: tf.zeros_initializer,
trainable: false);

adapt_variance = add_weight("variance",
mean_and_var_shape,
dtype: tf.float32,
initializer: tf.ones_initializer,
trainable: false);

count = add_weight("count",
Shape.Scalar,
dtype: tf.int64,
initializer: tf.zeros_initializer,
trainable: false);

finalize_state();
}
else
{
mean = input_mean * np.ones(mean_and_var_shape);
variance = input_variance * np.ones(mean_and_var_shape);
mean = tf.reshape(mean, _broadcast_shape);
variance = tf.reshape(variance, _broadcast_shape);
mean = tf.cast(mean, compute_dtype);
variance = tf.cast(variance, compute_dtype);
}
}

public override void reset_state()
{
if (input_mean != null && !built)
{
return;
}
adapt_mean.assign(tf.zeros_like(adapt_mean.AsTensor()));
adapt_variance.assign(tf.ones_like(adapt_variance.AsTensor()));
count.assign(tf.zeros_like(count.AsTensor()));
}

public override void finalize_state()
{
if (input_mean != null && !built)
{
return;
}
mean = tf.reshape(adapt_mean.AsTensor(), _broadcast_shape);
variance = tf.reshape(adapt_variance.AsTensor(), _broadcast_shape);
}

public override void update_state(Tensor data)
{
data = tf.cast(data, adapt_mean.dtype);
var (batch_mean, batch_variance) = tf.nn.moments(data, axes: _reduce_axis);
var batch_shape = tf.shape(data, out_type: count.dtype);

var batch_count = constant_op.constant(1L);
if (_reduce_axis != null)
{
var batch_reduce_shape = tf.gather(batch_shape, constant_op.constant(_reduce_axis));
batch_count = tf.reduce_prod(batch_reduce_shape);
}
var total_count = batch_count + count.AsTensor();
var batch_weight = tf.cast(batch_count, dtype: compute_dtype) / tf.cast(
total_count, dtype: compute_dtype);
var existing_weight = 1.0 - batch_weight;
var total_mean = adapt_mean.AsTensor() * existing_weight + batch_mean * batch_weight;

var total_variance = (
adapt_variance.AsTensor() + tf.square(adapt_mean.AsTensor() - total_mean)
) * existing_weight + (
batch_variance + tf.square(batch_mean - total_mean)
) * batch_weight;
adapt_mean.assign(total_mean);
adapt_variance.assign(total_variance);
count.assign(total_count);
}

public override Shape ComputeOutputShape(Shape input_shape)
{
return input_shape;
}

public override void adapt(Tensor data, int? batch_size = null, int? steps = null)
{
base.adapt(data, batch_size: batch_size, steps: steps);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
if (_args.Invert)
{
return mean + (
inputs * tf.maximum(tf.sqrt(variance), keras.backend.epsilon())
);
}
else
{
return (inputs - mean) / tf.maximum(
tf.sqrt(variance), keras.backend.epsilon());
}
}
}
}

+ 81
- 0
src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs View File

@@ -3,14 +3,95 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Engine.DataAdapters;

namespace Tensorflow.Keras.Layers
{
public class PreprocessingLayer : Layer
{
bool _is_compiled;
bool _is_adapted;
IVariableV1 _steps_per_execution;
PreprocessingLayerArgs _args;
public PreprocessingLayer(PreprocessingLayerArgs args) : base(args)
{
_args = args;
}

public override void adapt(Tensor data, int? batch_size = null, int? steps = null)
{
if (!_is_compiled)
{
compile();
}

if (built)
{
reset_state();
}

var data_handler = new DataHandler(new DataHandlerArgs
{
X = new Tensors(data),
BatchSize = _args.BatchSize,
Epochs = 1,
StepsPerExecution = _steps_per_execution
});

foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
foreach (var _ in data_handler.steps())
{
run_step(iterator);
}
}
finalize_state();
_is_adapted = true;
}

private void run_step(OwnedIterator iterator)
{
var data = iterator.next();
_adapt_maybe_build(data[0]);
update_state(data[0]);
}

public virtual void reset_state()
{

}

public virtual void finalize_state()
{

}

public virtual void update_state(Tensor data)
{

}

private void _adapt_maybe_build(Tensor data)
{
if (!built)
{
var data_shape = data.shape;
var data_shape_nones = Enumerable.Range(0, data.ndim).Select(x => -1).ToArray();
_args.BatchInputShape = BatchInputShape ?? new Shape(data_shape_nones);
build(data_shape);
built = true;
}
}

public void compile(bool run_eagerly = false, int steps_per_execution = 1)
{
_steps_per_execution = tf.Variable(
steps_per_execution,
dtype: tf.int64,
aggregation: VariableAggregation.OnlyFirstReplica
);

_is_compiled = true;
}
}
}

+ 1
- 0
test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs View File

@@ -1,5 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Keras;
using static Tensorflow.Binding;



+ 49
- 0
test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs View File

@@ -177,6 +177,55 @@ namespace TensorFlowNET.Keras.UnitTest
Assert.IsTrue(output[0].numpy().Equals(new[] { -0.99998f, 0.99998f }));
}

/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Normalization
/// </summary>
[TestMethod]
public void Normalization()
{
// Calculate a global mean and variance by analyzing the dataset in adapt().
var adapt_data = np.array(new[] { 1f, 2f, 3f, 4f, 5f });
var input_data = np.array(new[] { 1f, 2f, 3f });
var layer = tf.keras.layers.Normalization(axis: null);
layer.adapt(adapt_data);
var x = layer.Apply(input_data);
Assert.AreEqual(x.numpy(), new[] { -1.4142135f, -0.70710677f, 0f });

// Calculate a mean and variance for each index on the last axis.
adapt_data = np.array(new[,]
{
{ 0, 7, 4 },
{ 2, 9, 6 },
{ 0, 7, 4 },
{ 2, 9, 6 }
}, dtype: tf.float32);
input_data = np.array(new[,] { { 0, 7, 4 } }, dtype: tf.float32);
layer = tf.keras.layers.Normalization(axis: -1);
layer.adapt(adapt_data);
x = layer.Apply(input_data);
Equal(x.numpy().ToArray<float>(), new[] { -1f, -1f, -1f });

// Pass the mean and variance directly.
input_data = np.array(new[,] { { 1f }, { 2f }, { 3f } }, dtype: tf.float32);
layer = tf.keras.layers.Normalization(mean: 3f, variance: 2f);
x = layer.Apply(input_data);
Equal(x.numpy().ToArray<float>(), new[] { -1.4142135f, -0.70710677f, 0f });

// Use the layer to de-normalize inputs (after adapting the layer).
adapt_data = np.array(new[,]
{
{ 0, 7, 4 },
{ 2, 9, 6 },
{ 0, 7, 4 },
{ 2, 9, 6 }
}, dtype: tf.float32);
input_data = np.array(new[,] { { 1, 2, 3 } }, dtype: tf.float32);
layer = tf.keras.layers.Normalization(axis: -1, invert: true);
layer.adapt(adapt_data);
x = layer.Apply(input_data);
Equal(x.numpy().ToArray<float>(), new[] { -2f, -10f, -8f });
}

/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/CategoryEncoding
/// </summary>


Loading…
Cancel
Save