
Merge pull request #1039 from AsakusaRinne/fix_1036

Fix (#1036) and adjust the keras unittest
tags/v0.100.5-BERT-load
Haiping (GitHub) · 2 years ago
parent commit 4e78d3dd05
No known key found for this signature in database. GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 503 additions and 448 deletions
1. +104 -8    src/TensorFlowNET.Core/Operations/math_ops.cs
2. +4 -9      test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs
3. +1 -3      test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs
4. +2 -5      test/TensorFlowNET.Keras.UnitTest/GradientTest.cs
5. +1 -6      test/TensorFlowNET.Keras.UnitTest/InitializerTest.cs
6. +3 -5      test/TensorFlowNET.Keras.UnitTest/Layers/ActivationTest.cs
7. +6 -9      test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs
8. +7 -9      test/TensorFlowNET.Keras.UnitTest/Layers/CosineSimilarity.Test.cs
9. +4 -6      test/TensorFlowNET.Keras.UnitTest/Layers/Huber.Test.cs
10. +2 -4     test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Convolution.Test.cs
11. +35 -31   test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Cropping.Test.cs
12. +1 -2     test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Merging.Test.cs
13. +38 -33   test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Reshaping.Test.cs
14. +3 -6     test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
15. +7 -9     test/TensorFlowNET.Keras.UnitTest/Layers/LogCosh.Test.cs
16. +3 -5     test/TensorFlowNET.Keras.UnitTest/Layers/MeanAbsoluteError.Test.cs
17. +3 -5     test/TensorFlowNET.Keras.UnitTest/Layers/MeanAbsolutePercentageError.Test.cs
18. +4 -7     test/TensorFlowNET.Keras.UnitTest/Layers/MeanSquaredError.Test.cs
19. +3 -5     test/TensorFlowNET.Keras.UnitTest/Layers/MeanSquaredLogarithmicError.Test.cs
20. +0 -35    test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs
21. +2 -6     test/TensorFlowNET.Keras.UnitTest/Layers/PoolingTest.cs
22. +2 -10    test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs
23. +12 -19   test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
24. +37 -0    test/TensorFlowNET.Keras.UnitTest/Model/ModelBuildTest.cs
25. +6 -9     test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs
26. +200 -0   test/TensorFlowNET.Keras.UnitTest/Model/ModelSaveTest.cs
27. +1 -2     test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs
28. +4 -4     test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs
29. +1 -7     test/TensorFlowNET.Keras.UnitTest/OutputTest.cs
30. +7 -13    test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
31. +0 -176   test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs

src/TensorFlowNET.Core/Operations/math_ops.cs  (+104 -8)

@@ -905,13 +905,29 @@ namespace Tensorflow
var (a_reshape, a_free_dims, a_free_dims_static) = _tensordot_reshape(a, a_axes);
var (b_reshape, b_free_dims, b_free_dims_static) = _tensordot_reshape(b, b_axes, true);
var ab_matmul = matmul(a_reshape, b_reshape);
var dims = new List<int>();
dims.AddRange(a_free_dims);
dims.AddRange(b_free_dims);
if (ab_matmul.shape.Equals(dims))
return ab_matmul;
if(a_free_dims is int[] a_free_dims_list && b_free_dims is int[] b_free_dims_list)
{
var total_free_dims = a_free_dims_list.Concat(b_free_dims_list).ToArray();
if (ab_matmul.shape.IsFullyDefined && ab_matmul.shape.as_int_list().SequenceEqual(total_free_dims))
{
return ab_matmul;
}
else
{
return array_ops.reshape(ab_matmul, ops.convert_to_tensor(total_free_dims), name);
}
}
else
return array_ops.reshape(ab_matmul, tf.constant(dims.ToArray()), name: name);
{
var a_free_dims_tensor = ops.convert_to_tensor(a_free_dims, dtype: dtypes.int32);
var b_free_dims_tensor = ops.convert_to_tensor(b_free_dims, dtype: dtypes.int32);
var product = array_ops.reshape(ab_matmul, array_ops.concat(new[] { a_free_dims_tensor, b_free_dims_tensor }, 0), name);
if(a_free_dims_static is not null && b_free_dims_static is not null)
{
product.shape = new Shape(a_free_dims_static.Concat(b_free_dims_static).ToArray());
}
return product;
}
});
}

@@ -927,14 +943,42 @@ namespace Tensorflow
return (Binding.range(a.shape.ndim - axe, a.shape.ndim).ToArray(),
Binding.range(0, axe).ToArray());
}
else
else if(axes.rank == 1)
{
if (axes.shape[0] != 2)
{
throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}.");
}
(int a_axe, int b_axe) = (axes[0], axes[1]);
return (new[] { a_axe }, new[] { b_axe });
}
else if(axes.rank == 2)
{
if (axes.shape[0] != 2)
{
throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}.");
}
int[] a_axes = new int[axes.shape[1]];
int[] b_axes = new int[axes.shape[1]];
for(int i = 0; i < a_axes.Length; i++)
{
a_axes[i] = axes[0, i];
b_axes[i] = axes[1, i];
if (a_axes[i] == -1 || b_axes[i] == -1)
{
throw new ValueError($"Different number of contraction axes `a` and `b`," +
$"{len(a_axes)} != {len(b_axes)}.");
}
}
return (a_axes, b_axes);
}
else
{
throw new ValueError($"Invalid rank {axes.rank} to make tensor dot.");
}
}

static (Tensor, int[], int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
static (Tensor, object, int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
{
if (a.shape.IsFullyDefined && isinstance(axes, (typeof(int[]), typeof(Tuple))))
{
@@ -977,6 +1021,58 @@ namespace Tensorflow
var reshaped_a = array_ops.reshape(a_trans, new_shape);
return (reshaped_a, free_dims, free_dims);
}
else
{
int[] free_dims_static;
Tensor converted_shape_a, converted_axes, converted_free;
if (a.shape.ndim != -1)
{
var shape_a = a.shape.as_int_list();
for(int i = 0; i < axes.Length; i++)
{
if (axes[i] < 0)
{
axes[i] += shape_a.Length;
}
}
var free = Enumerable.Range(0, shape_a.Length).Where(i => !axes.Contains(i)).ToArray();

var axes_dims = axes.Select(i => shape_a[i]);
var free_dims = free.Select(i => shape_a[i]).ToArray();
free_dims_static = free_dims;
converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes");
converted_free = ops.convert_to_tensor(free, dtypes.int32, "free");
converted_shape_a = array_ops.shape(a);
}
else
{
free_dims_static = null;
converted_shape_a = array_ops.shape(a);
var rank_a = array_ops.rank(a);
converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes");
converted_axes = array_ops.where_v2(converted_axes >= 0, converted_axes, converted_axes + rank_a);
(converted_free, var _) = gen_ops.list_diff(gen_math_ops.range(ops.convert_to_tensor(0), rank_a, ops.convert_to_tensor(1)),
converted_axes, dtypes.int32);
}
var converted_free_dims = array_ops.gather(converted_shape_a, converted_free);
var converted_axes_dims = array_ops.gather(converted_shape_a, converted_axes);
var prod_free_dims = reduce_prod(converted_free_dims);
var prod_axes_dims = reduce_prod(converted_axes_dims);
Tensor reshaped_a;
if (flipped)
{
var perm = array_ops.concat(new[] { converted_axes, converted_free }, 0);
var new_shape = array_ops.stack(new[] { prod_axes_dims, prod_free_dims });
reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape);
}
else
{
var perm = array_ops.concat(new[] { converted_free, converted_axes }, 0);
var new_shape = array_ops.stack(new[] { prod_free_dims, prod_axes_dims });
reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape);
}
return (reshaped_a, converted_free_dims, free_dims_static);
}

throw new NotImplementedException("_tensordot_reshape");
}
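
The math_ops.cs change above extends tensordot's axes handling beyond a plain integer: a rank-1 axes pair and a rank-2 axes matrix are now mapped onto (a_axes, b_axes) before the transpose/reshape/matmul, and tensors without a fully defined shape fall back to dynamic shape ops. Below is a minimal sketch of the call patterns this enables; it assumes the operation is exposed as tf.tensordot(a, b, axes) mirroring the Python API, so the exact entry point, overloads, and parameter names may differ.

using Tensorflow;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

var a = tf.ones((3, 4, 5));
var b = tf.ones((4, 3, 2));

// Scalar axes n: contract the last n dims of `a` with the first n dims of `b`.
// (axes may need to be wrapped in an NDArray depending on the available overloads.)
var c1 = tf.tensordot(tf.ones((3, 4)), tf.ones((4, 2)), 1);                     // shape (3, 2)

// Rank-1 axes of length 2: contract dim 1 of `a` with dim 0 of `b`.
var c2 = tf.tensordot(a, b, np.array(new[] { 1, 0 }));                          // shape (3, 5, 3, 2)

// Rank-2 axes (the case added by the new branch): contract dims {0, 1} of `a`
// with dims {1, 0} of `b`; the free dims of `a` followed by those of `b` form the result.
var c3 = tf.tensordot(a, b, np.array(new[,] { { 0, 1 }, { 1, 0 } }));           // shape (5, 2)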


test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs  (+4 -9)

@@ -1,16 +1,11 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.Keras.UnitTest.Helpers;
using static Tensorflow.Binding;
using Tensorflow;
using Tensorflow.Keras.Optimizers;
using System.Collections.Generic;
using Tensorflow.Keras.Callbacks;
using Tensorflow.Keras.Engine;
using System.Collections.Generic;
using static Tensorflow.KerasApi;
using Tensorflow.Keras;


namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest.Callbacks
{
[TestClass]
public class EarlystoppingTest
@@ -31,7 +26,7 @@ namespace TensorFlowNET.Keras.UnitTest
layers.Dense(10)
});

model.summary();

model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
@@ -55,7 +50,7 @@ namespace TensorFlowNET.Keras.UnitTest
var callbacks = new List<ICallback>();
callbacks.add(earlystop);

model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs,callbacks:callbacks);
model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs, callbacks: callbacks);
}

}


test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs  (+1 -3)

@@ -1,10 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Keras;
using static Tensorflow.Binding;

namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest
{
public class EagerModeTestBase
{


test/TensorFlowNET.Keras.UnitTest/GradientTest.cs  (+2 -5)

@@ -1,14 +1,11 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Linq;
using Tensorflow;
using Tensorflow.Keras.Engine;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.NumPy;
using System;
using Tensorflow.Keras.Optimizers;

namespace TensorFlowNET.Keras.UnitTest;
namespace Tensorflow.Keras.UnitTest;

[TestClass]
public class GradientTest : EagerModeTestBase


test/TensorFlowNET.Keras.UnitTest/InitializerTest.cs  (+1 -6)

@@ -1,12 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using TensorFlowNET.Keras.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.Keras.UnitTest;
namespace Tensorflow.Keras.UnitTest;

[TestClass]
public class InitializerTest : EagerModeTestBase


test/TensorFlowNET.Keras.UnitTest/Layers/ActivationTest.cs  (+3 -5)

@@ -1,12 +1,10 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using static Tensorflow.Binding;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow;

namespace TensorFlowNET.Keras.UnitTest {
namespace Tensorflow.Keras.UnitTest.Layers
{
[TestClass]
public class ActivationTest : EagerModeTestBase
{


test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs  (+6 -9)

@@ -1,15 +1,11 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Utils;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.Keras.Layers;
using Tensorflow;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Utils;

namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest.Layers
{
[TestClass]
public class AttentionTest : EagerModeTestBase
@@ -118,7 +114,8 @@ namespace TensorFlowNET.Keras.UnitTest
} }, dtype: np.float32);
var attention_layer = (Attention)keras.layers.Attention(score_mode: "concat");
//attention_layer.concat_score_weight = 1;
attention_layer.concat_score_weight = base_layer_utils.make_variable(new VariableArgs() {
attention_layer.concat_score_weight = base_layer_utils.make_variable(new VariableArgs()
{
Name = "concat_score_weight",
Shape = (1),
DType = TF_DataType.TF_FLOAT,
@@ -156,7 +153,7 @@ namespace TensorFlowNET.Keras.UnitTest

var query = keras.Input(shape: (4, 8));
var value = keras.Input(shape: (2, 8));
var mask_tensor = keras.Input(shape:(4, 2));
var mask_tensor = keras.Input(shape: (4, 2));
var attention_layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2);
attention_layer.Apply(new Tensor[] { query, value, mask_tensor });



test/TensorFlowNET.Keras.UnitTest/Layers/CosineSimilarity.Test.cs  (+7 -9)

@@ -1,11 +1,9 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using Tensorflow;
using Tensorflow.Keras.Losses;
using static Tensorflow.Binding;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest.Layers
{
[TestClass]
public class CosineSimilarity
@@ -16,7 +14,7 @@ namespace TensorFlowNET.Keras.UnitTest
NDArray y_pred_float = new float[,] { { 1.0f, 0.0f }, { 1.0f, 1.0f } };

[TestMethod]
public void _Default()
{
//>>> # Using 'auto'/'sum_over_batch_size' reduction type.
@@ -27,7 +25,7 @@ namespace TensorFlowNET.Keras.UnitTest
//>>> # loss = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1))
//>>> # = -((0. + 0.) + (0.5 + 0.5)) / 2
//-0.5
var loss = keras.losses.CosineSimilarity(axis : 1);
var loss = keras.losses.CosineSimilarity(axis: 1);
var call = loss.Call(y_true_float, y_pred_float);
Assert.AreEqual((NDArray)(-0.49999997f), call.numpy());
}
@@ -41,7 +39,7 @@ namespace TensorFlowNET.Keras.UnitTest
//- 0.0999
var loss = keras.losses.CosineSimilarity();
var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.8f, 0.2f });
Assert.AreEqual((NDArray) (- 0.099999994f), call.numpy());
Assert.AreEqual((NDArray)(-0.099999994f), call.numpy());
}

[TestMethod]
@@ -53,7 +51,7 @@ namespace TensorFlowNET.Keras.UnitTest
//... reduction = tf.keras.losses.Reduction.SUM)
//>>> cosine_loss(y_true, y_pred).numpy()
//- 0.999
var loss = keras.losses.CosineSimilarity(axis: 1,reduction : ReductionV2.SUM);
var loss = keras.losses.CosineSimilarity(axis: 1, reduction: ReductionV2.SUM);
var call = loss.Call(y_true_float, y_pred_float);
Assert.AreEqual((NDArray)(-0.99999994f), call.numpy());
}
@@ -67,7 +65,7 @@ namespace TensorFlowNET.Keras.UnitTest
//... reduction = tf.keras.losses.Reduction.NONE)
//>>> cosine_loss(y_true, y_pred).numpy()
//array([-0., -0.999], dtype = float32)
var loss = keras.losses.CosineSimilarity(axis :1, reduction: ReductionV2.NONE);
var loss = keras.losses.CosineSimilarity(axis: 1, reduction: ReductionV2.NONE);
var call = loss.Call(y_true_float, y_pred_float);
Assert.AreEqual((NDArray)new float[] { -0f, -0.99999994f }, call.numpy());
}


test/TensorFlowNET.Keras.UnitTest/Layers/Huber.Test.cs  (+4 -6)

@@ -1,11 +1,9 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using Tensorflow;
using Tensorflow.Keras.Losses;
using static Tensorflow.Binding;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest.Layers
{
[TestClass]
public class Huber
@@ -16,7 +14,7 @@ namespace TensorFlowNET.Keras.UnitTest
NDArray y_pred_float = new float[,] { { 0.6f, 0.4f }, { 0.4f, 0.6f } };

[TestMethod]
public void _Default()
{
//>>> # Using 'auto'/'sum_over_batch_size' reduction type.
@@ -49,7 +47,7 @@ namespace TensorFlowNET.Keras.UnitTest
//... reduction = tf.keras.losses.Reduction.SUM)
//>>> h(y_true, y_pred).numpy()
//0.31
var loss = keras.losses.Huber(reduction : ReductionV2.SUM);
var loss = keras.losses.Huber(reduction: ReductionV2.SUM);
var call = loss.Call(y_true_float, y_pred_float);
Assert.AreEqual((NDArray)0.31f, call.numpy());
}


test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Convolution.Test.cs  (+2 -4)

@@ -1,10 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using Tensorflow;
using Tensorflow.Operations;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest.Layers
{
[TestClass]
public class LayersConvolutionTest : EagerModeTestBase
@@ -14,7 +12,7 @@ namespace TensorFlowNET.Keras.UnitTest
{
var filters = 8;

var conv = keras.layers.Conv1D(filters, kernel_size: 3, activation: "linear");
var conv = keras.layers.Conv1D(filters, kernel_size: 3, activation: "linear");

var x = np.arange(256.0f).reshape((8, 8, 4));
var y = conv.Apply(x);


test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Cropping.Test.cs  (+35 -31)

@@ -1,39 +1,43 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest {
[TestClass]
public class LayersCroppingTest : EagerModeTestBase {
[TestMethod]
public void Cropping1D () {
Shape input_shape = (1, 5, 2);
var x = tf.zeros(input_shape);
var cropping_1d = keras.layers.Cropping1D(new[] { 1, 2 });
var y = cropping_1d.Apply(x);
Assert.AreEqual((1, 2, 2), y.shape);
}
namespace Tensorflow.Keras.UnitTest.Layers
{
[TestClass]
public class LayersCroppingTest : EagerModeTestBase
{
[TestMethod]
public void Cropping1D()
{
Shape input_shape = (1, 5, 2);
var x = tf.zeros(input_shape);
var cropping_1d = keras.layers.Cropping1D(new[] { 1, 2 });
var y = cropping_1d.Apply(x);
Assert.AreEqual((1, 2, 2), y.shape);
}

[TestMethod]
public void Cropping2D () {
Shape input_shape = (1, 5, 6, 1);
NDArray cropping = new NDArray(new[,] { { 1, 2 }, { 1, 3 } });
var x = tf.zeros(input_shape);
var cropping_2d = keras.layers.Cropping2D(cropping);
var y = cropping_2d.Apply(x);
Assert.AreEqual((1, 2, 2, 1), y.shape);
}
[TestMethod]
public void Cropping2D()
{
Shape input_shape = (1, 5, 6, 1);
NDArray cropping = new NDArray(new[,] { { 1, 2 }, { 1, 3 } });
var x = tf.zeros(input_shape);
var cropping_2d = keras.layers.Cropping2D(cropping);
var y = cropping_2d.Apply(x);
Assert.AreEqual((1, 2, 2, 1), y.shape);
}

[TestMethod]
public void Cropping3D () {
Shape input_shape = new Shape(1, 5, 6, 7, 1);
NDArray cropping = new NDArray(new[,] { { 1, 2 }, { 1, 3 }, { 1, 4 } });
var x = tf.zeros(input_shape);
var cropping_3d = keras.layers.Cropping3D(cropping);
var y = cropping_3d.Apply(x);
Assert.AreEqual(new Shape(1, 2, 2, 2, 1), y.shape);
}
}
[TestMethod]
public void Cropping3D()
{
Shape input_shape = new Shape(1, 5, 6, 7, 1);
NDArray cropping = new NDArray(new[,] { { 1, 2 }, { 1, 3 }, { 1, 4 } });
var x = tf.zeros(input_shape);
var cropping_3d = keras.layers.Cropping3D(cropping);
var y = cropping_3d.Apply(x);
Assert.AreEqual(new Shape(1, 2, 2, 2, 1), y.shape);
}
}
}

test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Merging.Test.cs  (+1 -2)

@@ -1,9 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using Tensorflow;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest.Layers
{
[TestClass]
public class LayersMergingTest : EagerModeTestBase


test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Reshaping.Test.cs  (+38 -33)

@@ -1,43 +1,48 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest {
[TestClass]
public class LayersReshapingTest : EagerModeTestBase {
[TestMethod]
public void ZeroPadding2D () {
Shape input_shape = (1, 1, 2, 2);
var x = np.arange(input_shape.size).reshape(input_shape);
var zero_padding_2d = keras.layers.ZeroPadding2D(new[,] { { 1, 0 }, { 1, 0 } });
var y = zero_padding_2d.Apply(x);
Assert.AreEqual((1, 2, 3, 2), y.shape);
}
namespace Tensorflow.Keras.UnitTest.Layers
{
[TestClass]
public class LayersReshapingTest : EagerModeTestBase
{
[TestMethod]
public void ZeroPadding2D()
{
Shape input_shape = (1, 1, 2, 2);
var x = np.arange(input_shape.size).reshape(input_shape);
var zero_padding_2d = keras.layers.ZeroPadding2D(new[,] { { 1, 0 }, { 1, 0 } });
var y = zero_padding_2d.Apply(x);
Assert.AreEqual((1, 2, 3, 2), y.shape);
}

[TestMethod]
public void UpSampling2D () {
Shape input_shape = (2, 2, 1, 3);
var x = np.arange(input_shape.size).reshape(input_shape);
var y = keras.layers.UpSampling2D(size: (1, 2)).Apply(x);
Assert.AreEqual((2, 2, 2, 3), y.shape);
}
[TestMethod]
public void UpSampling2D()
{
Shape input_shape = (2, 2, 1, 3);
var x = np.arange(input_shape.size).reshape(input_shape);
var y = keras.layers.UpSampling2D(size: (1, 2)).Apply(x);
Assert.AreEqual((2, 2, 2, 3), y.shape);
}

[TestMethod]
public void Reshape () {
var inputs = tf.zeros((10, 5, 20));
var outputs = keras.layers.LeakyReLU().Apply(inputs);
outputs = keras.layers.Reshape((20, 5)).Apply(outputs);
Assert.AreEqual((10, 20, 5), outputs.shape);
}
[TestMethod]
public void Reshape()
{
var inputs = tf.zeros((10, 5, 20));
var outputs = keras.layers.LeakyReLU().Apply(inputs);
outputs = keras.layers.Reshape((20, 5)).Apply(outputs);
Assert.AreEqual((10, 20, 5), outputs.shape);
}

[TestMethod]
public void Permute () {
var inputs = tf.zeros((2, 3, 4, 5));
var outputs = keras.layers.Permute(new int[] { 3, 2, 1 }).Apply(inputs);
Assert.AreEqual((2, 5, 4, 3), outputs.shape);
}
[TestMethod]
public void Permute()
{
var inputs = tf.zeros((2, 3, 4, 5));
var outputs = keras.layers.Permute(new int[] { 3, 2, 1 }).Apply(inputs);
Assert.AreEqual((2, 5, 4, 3), outputs.shape);
}

}
}
}

test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs  (+3 -6)

@@ -1,13 +1,10 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using System.Collections.Generic;
using Tensorflow;
using Tensorflow.Keras;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System.Linq;

namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest.Layers
{
/// <summary>
/// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers
@@ -235,7 +232,7 @@ namespace TensorFlowNET.Keras.UnitTest
// one-hot
var inputs = np.array(new[] { 3, 2, 0, 1 });
var layer = tf.keras.layers.CategoryEncoding(4);
Tensor output = layer.Apply(inputs);
Assert.AreEqual((4, 4), output.shape);
Assert.IsTrue(output[0].numpy().Equals(new[] { 0, 0, 0, 1f }));


test/TensorFlowNET.Keras.UnitTest/Layers/LogCosh.Test.cs  (+7 -9)

@@ -1,11 +1,9 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using Tensorflow;
using Tensorflow.Keras.Losses;
using static Tensorflow.Binding;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest.Layers
{
[TestClass]
public class LogCosh
@@ -16,7 +14,7 @@ namespace TensorFlowNET.Keras.UnitTest
NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 0.0f, 0.0f } };

[TestMethod]
public void _Default()
{
//>>> # Using 'auto'/'sum_over_batch_size' reduction type.
@@ -32,9 +30,9 @@ namespace TensorFlowNET.Keras.UnitTest

public void _Sample_Weight()
{
//>>> # Calling with 'sample_weight'.
//>>> l(y_true, y_pred, sample_weight =[0.8, 0.2]).numpy()
//0.087
//>>> # Calling with 'sample_weight'.
//>>> l(y_true, y_pred, sample_weight =[0.8, 0.2]).numpy()
//0.087
var loss = keras.losses.LogCosh();
var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.8f, 0.2f });
Assert.AreEqual((NDArray)0.08675616f, call.numpy());
@@ -49,7 +47,7 @@ namespace TensorFlowNET.Keras.UnitTest
//... reduction = tf.keras.losses.Reduction.SUM)
//>>> l(y_true, y_pred).numpy()
//0.217
var loss = keras.losses.LogCosh(reduction : ReductionV2.SUM);
var loss = keras.losses.LogCosh(reduction: ReductionV2.SUM);
var call = loss.Call(y_true_float, y_pred_float);
Assert.AreEqual((NDArray)0.2168904f, call.numpy());
}


test/TensorFlowNET.Keras.UnitTest/Layers/MeanAbsoluteError.Test.cs  (+3 -5)

@@ -1,11 +1,9 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using Tensorflow;
using Tensorflow.Keras.Losses;
using static Tensorflow.Binding;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest.Layers
{
[TestClass]
public class MeanAbsoluteError
@@ -50,7 +48,7 @@ namespace TensorFlowNET.Keras.UnitTest
//... reduction = tf.keras.losses.Reduction.SUM)
//>>> mae(y_true, y_pred).numpy()
//1.0
var loss = keras.losses.MeanAbsoluteError( reduction: ReductionV2.SUM);
var loss = keras.losses.MeanAbsoluteError(reduction: ReductionV2.SUM);
var call = loss.Call(y_true_float, y_pred_float);
Assert.AreEqual((NDArray)(1.0f), call.numpy());
}


test/TensorFlowNET.Keras.UnitTest/Layers/MeanAbsolutePercentageError.Test.cs  (+3 -5)

@@ -1,11 +1,9 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using Tensorflow;
using Tensorflow.Keras.Losses;
using static Tensorflow.Binding;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest.Layers
{
[TestClass]
public class MeanAbsolutePercentageError
@@ -49,7 +47,7 @@ namespace TensorFlowNET.Keras.UnitTest
//... reduction = tf.keras.losses.Reduction.SUM)
//>>> mape(y_true, y_pred).numpy()
//100.
var loss = keras.losses.MeanAbsolutePercentageError( reduction: ReductionV2.SUM);
var loss = keras.losses.MeanAbsolutePercentageError(reduction: ReductionV2.SUM);
var call = loss.Call(y_true_float, y_pred_float);
Assert.AreEqual((NDArray)(100f), call.numpy());
}


test/TensorFlowNET.Keras.UnitTest/Layers/MeanSquaredError.Test.cs  (+4 -7)

@@ -1,14 +1,11 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using Tensorflow;
using Tensorflow.Keras.Losses;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest.Layers
{
[TestClass]
public class MeanSquaredErrorTest
public class MeanSquaredErrorTest
{
//https://keras.io/api/losses/regression_losses/#meansquarederror-class

@@ -16,7 +13,7 @@ namespace TensorFlowNET.Keras.UnitTest
private NDArray y_pred = new double[,] { { 1.0, 1.0 }, { 1.0, 0.0 } };

[TestMethod]
public void Mse_Double()
{
var mse = keras.losses.MeanSquaredError();
@@ -25,7 +22,7 @@ namespace TensorFlowNET.Keras.UnitTest
}

[TestMethod]
public void Mse_Float()
{
NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } };


test/TensorFlowNET.Keras.UnitTest/Layers/MeanSquaredLogarithmicError.Test.cs  (+3 -5)

@@ -1,11 +1,9 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using Tensorflow;
using Tensorflow.Keras.Losses;
using static Tensorflow.Binding;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest.Layers
{
[TestClass]
public class MeanSquaredLogarithmicError
@@ -49,7 +47,7 @@ namespace TensorFlowNET.Keras.UnitTest
//... reduction = tf.keras.losses.Reduction.SUM)
//>>> msle(y_true, y_pred).numpy()
//0.480
var loss = keras.losses.MeanSquaredLogarithmicError( reduction: ReductionV2.SUM);
var loss = keras.losses.MeanSquaredLogarithmicError(reduction: ReductionV2.SUM);
var call = loss.Call(y_true_float, y_pred_float);
Assert.AreEqual((NDArray)(0.48045287f), call.numpy());
}


test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs  (+0 -35)

@@ -1,35 +0,0 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.Keras.Engine;
using System.Diagnostics;
using static Tensorflow.KerasApi;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Models;

namespace TensorFlowNET.Keras.UnitTest
{
/// <summary>
/// https://www.tensorflow.org/guide/keras/save_and_serialize
/// </summary>
[TestClass]
public class ModelSaveTest : EagerModeTestBase
{
[TestMethod]
public void GetAndFromConfig()
{
var model = GetFunctionalModel();
var config = model.get_config();
Debug.Assert(config is FunctionalConfig);
var new_model = new ModelsApi().from_config(config as FunctionalConfig);
Assert.AreEqual(model.Layers.Count, new_model.Layers.Count);
}

IModel GetFunctionalModel()
{
// Create a simple model.
var inputs = keras.Input(shape: 32);
var dense_layer = keras.layers.Dense(1);
var outputs = dense_layer.Apply(inputs);
return keras.Model(inputs, outputs);
}
}
}

test/TensorFlowNET.Keras.UnitTest/Layers/PoolingTest.cs  (+2 -6)

@@ -1,12 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Microsoft.VisualBasic;

namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest.Layers
{
/// <summary>
/// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers
@@ -231,7 +227,7 @@ namespace TensorFlowNET.Keras.UnitTest
public void Max1DPoolingChannelsLast()
{
var x = input_array_1D;
var pool = keras.layers.MaxPooling1D(pool_size:2, strides:1);
var pool = keras.layers.MaxPooling1D(pool_size: 2, strides: 1);
var y = pool.Apply(x);

Assert.AreEqual(4, y.shape[0]);


test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs  (+2 -10)

@@ -1,16 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Tensorflow;
using Tensorflow.NumPy;
using TensorFlowNET.Keras.UnitTest;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest;
namespace Tensorflow.Keras.UnitTest.Losses;

[TestClass]
public class LossesTest : EagerModeTestBase
@@ -47,7 +39,7 @@ public class LossesTest : EagerModeTestBase
// Using 'none' reduction type.
bce = tf.keras.losses.BinaryCrossentropy(from_logits: true, reduction: Reduction.NONE);
loss = bce.Call(y_true, y_pred);
Assert.AreEqual(new float[] { 0.23515666f, 1.4957594f}, loss.numpy());
Assert.AreEqual(new float[] { 0.23515666f, 1.4957594f }, loss.numpy());
}

/// <summary>


test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs  (+12 -19)

@@ -1,15 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Tensorflow;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest;
namespace Tensorflow.Keras.UnitTest.Layers.Metrics;

[TestClass]
public class MetricsTest : EagerModeTestBase
@@ -40,7 +33,7 @@ public class MetricsTest : EagerModeTestBase
[TestMethod]
public void BinaryAccuracy()
{
var y_true = np.array(new[,] { { 1 }, { 1 },{ 0 }, { 0 } });
var y_true = np.array(new[,] { { 1 }, { 1 }, { 0 }, { 0 } });
var y_pred = np.array(new[,] { { 0.98f }, { 1f }, { 0f }, { 0.6f } });
var m = tf.keras.metrics.BinaryAccuracy();
m.update_state(y_true, y_pred);
@@ -183,17 +176,17 @@ public class MetricsTest : EagerModeTestBase
public void HammingLoss()
{
// multi-class hamming loss
var y_true = np.array(new[,]
{
{ 1, 0, 0, 0 },
{ 0, 0, 1, 0 },
{ 0, 0, 0, 1 },
{ 0, 1, 0, 0 }
var y_true = np.array(new[,]
{
{ 1, 0, 0, 0 },
{ 0, 0, 1, 0 },
{ 0, 0, 0, 1 },
{ 0, 1, 0, 0 }
});
var y_pred = np.array(new[,]
{
{ 0.8f, 0.1f, 0.1f, 0.0f },
{ 0.2f, 0.0f, 0.8f, 0.0f },
var y_pred = np.array(new[,]
{
{ 0.8f, 0.1f, 0.1f, 0.0f },
{ 0.2f, 0.0f, 0.8f, 0.0f },
{ 0.05f, 0.05f, 0.1f, 0.8f },
{ 1.0f, 0.0f, 0.0f, 0.0f }
});


test/TensorFlowNET.Keras.UnitTest/Model/ModelBuildTest.cs  (+37 -0)

@@ -0,0 +1,37 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.UnitTest.Model
{
[TestClass]
public class ModelBuildTest
{
[TestMethod]
public void DenseBuild()
{
// two dimensions input with unknown batchsize
var input = tf.keras.layers.Input((17, 60));
var dense = tf.keras.layers.Dense(64);
var output = dense.Apply(input);
var model = tf.keras.Model(input, output);

// one dimensions input with unknown batchsize
var input_2 = tf.keras.layers.Input((60));
var dense_2 = tf.keras.layers.Dense(64);
var output_2 = dense.Apply(input_2);
var model_2 = tf.keras.Model(input_2, output_2);

// two dimensions input with specified batchsize
var input_3 = tf.keras.layers.Input((17, 60), 8);
var dense_3 = tf.keras.layers.Dense(64);
var output_3 = dense.Apply(input_3);
var model_3 = tf.keras.Model(input_3, output_3);

// one dimensions input with specified batchsize
var input_4 = tf.keras.layers.Input((60), 8);
var dense_4 = tf.keras.layers.Dense(64);
var output_4 = dense.Apply(input_4);
var model_4 = tf.keras.Model(input_4, output_4);
}
}
}

test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs → test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs  (+6 -9)

@@ -1,18 +1,15 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Linq;
using Tensorflow;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.UnitTest.Helpers;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest.SaveModel;
namespace Tensorflow.Keras.UnitTest.Model;

[TestClass]
public class SequentialModelLoad
public class ModelLoadTest
{
[TestMethod]
public void SimpleModelFromAutoCompile()
@@ -46,7 +43,7 @@ public class SequentialModelLoad
[TestMethod]
public void AlexnetFromSequential()
{
new SequentialModelSave().AlexnetFromSequential();
new ModelSaveTest().AlexnetFromSequential();
var model = tf.keras.models.load_model(@"./alexnet_from_sequential");
model.summary();

@@ -89,7 +86,7 @@ public class SequentialModelLoad
var model = tf.keras.models.load_model(@"D:\development\tf.net\models\VGG19");
model.summary();

var classify_model = keras.Sequential(new System.Collections.Generic.List<Tensorflow.Keras.ILayer>()
var classify_model = keras.Sequential(new System.Collections.Generic.List<ILayer>()
{
model,
keras.layers.Flatten(),
@@ -100,7 +97,7 @@ public class SequentialModelLoad
classify_model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

var x = np.random.uniform(0, 1, (8, 512, 512, 3));
var y = np.ones((8));
var y = np.ones(8);

classify_model.fit(x, y, batch_size: 4);
}
@@ -110,7 +107,7 @@ public class SequentialModelLoad
public void TestModelBeforeTF2_5()
{
var a = keras.layers;
var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Model;
var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model;
model.summary();
}
}

test/TensorFlowNET.Keras.UnitTest/Model/ModelSaveTest.cs  (+200 -0)

@@ -0,0 +1,200 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using System.Diagnostics;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Models;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.UnitTest.Helpers;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.UnitTest.Model
{
/// <summary>
/// https://www.tensorflow.org/guide/keras/save_and_serialize
/// </summary>
[TestClass]
public class ModelSaveTest : EagerModeTestBase
{
[TestMethod]
public void GetAndFromConfig()
{
var model = GetFunctionalModel();
var config = model.get_config();
Debug.Assert(config is FunctionalConfig);
var new_model = new ModelsApi().from_config(config as FunctionalConfig);
Assert.AreEqual(model.Layers.Count, new_model.Layers.Count);
}

IModel GetFunctionalModel()
{
// Create a simple model.
var inputs = keras.Input(shape: 32);
var dense_layer = keras.layers.Dense(1);
var outputs = dense_layer.Apply(inputs);
return keras.Model(inputs, outputs);
}

[TestMethod]
public void SimpleModelFromAutoCompile()
{
var inputs = tf.keras.layers.Input((28, 28, 1));
var x = tf.keras.layers.Flatten().Apply(inputs);
x = tf.keras.layers.Dense(100, activation: "relu").Apply(x);
x = tf.keras.layers.Dense(units: 10).Apply(x);
var outputs = tf.keras.layers.Softmax(axis: 1).Apply(x);
var model = tf.keras.Model(inputs, outputs);

model.compile(new Adam(0.001f),
tf.keras.losses.SparseCategoricalCrossentropy(),
new string[] { "accuracy" });

var data_loader = new MnistModelLoader();
var num_epochs = 1;
var batch_size = 50;

var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 58000,
}).Result;

model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);

model.save("./pb_simple_compile", save_format: "tf");
}

[TestMethod]
public void SimpleModelFromSequential()
{
var model = keras.Sequential(new List<ILayer>()
{
tf.keras.layers.InputLayer((28, 28, 1)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(100, "relu"),
tf.keras.layers.Dense(10),
tf.keras.layers.Softmax()
});

model.summary();

model.compile(new Adam(0.001f), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

var data_loader = new MnistModelLoader();
var num_epochs = 1;
var batch_size = 50;

var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 58000,
}).Result;

model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);

model.save("./pb_simple_sequential", save_format: "tf");
}

[TestMethod]
public void AlexnetFromSequential()
{
var model = keras.Sequential(new List<ILayer>()
{
tf.keras.layers.InputLayer((227, 227, 3)),
tf.keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.MaxPooling2D((3, 3), strides:(2, 2)),

tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: "relu"),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),

tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
tf.keras.layers.BatchNormalization(),

tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
tf.keras.layers.BatchNormalization(),

tf.keras.layers.Conv2D(256, (3, 3), (1, 1), "same", activation: "relu"),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),

tf.keras.layers.Flatten(),
tf.keras.layers.Dense(4096, activation: "relu"),
tf.keras.layers.Dropout(0.5f),

tf.keras.layers.Dense(4096, activation: "relu"),
tf.keras.layers.Dropout(0.5f),

tf.keras.layers.Dense(1000, activation: "linear"),
tf.keras.layers.Softmax(1)
});

model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

var num_epochs = 1;
var batch_size = 8;

var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);

model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);

model.save("./alexnet_from_sequential", save_format: "tf");

// The saved model can be tested with the following Python code:
#region alexnet_python_code
//import pathlib
//import tensorflow as tf

//def func(a):
// return -a

//if __name__ == '__main__':
// model = tf.keras.models.load_model("./pb_alex_sequential")
// model.summary()

// num_classes = 5
// batch_size = 128
// img_height = 227
// img_width = 227
// epochs = 100

// dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
// data_dir = tf.keras.utils.get_file('flower_photos', origin = dataset_url, untar = True)
// data_dir = pathlib.Path(data_dir)

// train_ds = tf.keras.preprocessing.image_dataset_from_directory(
// data_dir,
// validation_split = 0.2,
// subset = "training",
// seed = 123,
// image_size = (img_height, img_width),
// batch_size = batch_size)

// val_ds = tf.keras.preprocessing.image_dataset_from_directory(
// data_dir,
// validation_split = 0.2,
// subset = "validation",
// seed = 123,
// image_size = (img_height, img_width),
// batch_size = batch_size)


// model.compile(optimizer = 'adam',
// loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True),
// metrics =['accuracy'])

// model.build((None, img_height, img_width, 3))

// history = model.fit(
// train_ds,
// validation_data = val_ds,
// epochs = epochs
// )
#endregion
}
}
}

test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs  (+1 -2)

@@ -1,11 +1,10 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Keras.Optimizers;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest
{
[TestClass]
public class MultiInputModelTest


test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs  (+4 -4)

@@ -1,12 +1,12 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Threading.Tasks;
using Tensorflow.Keras.Engine;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System.Threading.Tasks;
using Tensorflow.NumPy;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest
{
[TestClass]
public class MultiThreads


test/TensorFlowNET.Keras.UnitTest/OutputTest.cs  (+1 -7)

@@ -1,14 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.Keras;

namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest
{
[TestClass]
public class OutputTest


test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs  (+7 -13)

@@ -1,14 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Linq;
using System.Collections.Generic;
using System.Text;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;
using Tensorflow;
using Tensorflow.Keras.Datasets;

namespace TensorFlowNET.Keras.UnitTest
namespace Tensorflow.Keras.UnitTest
{
[TestClass]
public class PreprocessingTests : EagerModeTestBase
@@ -71,8 +65,8 @@ namespace TensorFlowNET.Keras.UnitTest

Assert.AreEqual(28, tokenizer.word_index.Count);

Assert.AreEqual(1, tokenizer.word_index[OOV]);
Assert.AreEqual(8, tokenizer.word_index["worst"]);
Assert.AreEqual(1, tokenizer.word_index[OOV]);
Assert.AreEqual(8, tokenizer.word_index["worst"]);
Assert.AreEqual(13, tokenizer.word_index["number"]);
Assert.AreEqual(17, tokenizer.word_index["were"]);
}
@@ -204,13 +198,13 @@ namespace TensorFlowNET.Keras.UnitTest

for (var i = 0; i < sequences.Count; i++)
for (var j = 0; j < sequences[i].Length; j++)
Assert.AreNotEqual(tokenizer.word_index[OOV], sequences[i][j]);
Assert.AreNotEqual(tokenizer.word_index[OOV], sequences[i][j]);
}

[TestMethod]
public void TokenizeTextsToSequencesWithOOVPresent()
{
var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV, num_words:20);
var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV, num_words: 20);
tokenizer.fit_on_texts(texts);

var sequences = tokenizer.texts_to_sequences(texts);
@@ -255,7 +249,7 @@ namespace TensorFlowNET.Keras.UnitTest
tokenizer.fit_on_texts(texts);

var sequences = tokenizer.texts_to_sequences(texts);
var padded = keras.preprocessing.sequence.pad_sequences(sequences,maxlen:15);
var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 15);

Assert.AreEqual(4, padded.dims[0]);
Assert.AreEqual(15, padded.dims[1]);
@@ -348,7 +342,7 @@ namespace TensorFlowNET.Keras.UnitTest

Assert.AreEqual(27, tokenizer.word_index.Count);

var matrix = tokenizer.texts_to_matrix(texts, mode:"count");
var matrix = tokenizer.texts_to_matrix(texts, mode: "count");

Assert.AreEqual(texts.Length, matrix.dims[0]);



test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs  (+0 -176)

@@ -1,176 +0,0 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using Tensorflow;
using Tensorflow.Keras;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.UnitTest.Helpers;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest.SaveModel;

[TestClass]
public class SequentialModelSave
{
[TestMethod]
public void SimpleModelFromAutoCompile()
{
var inputs = tf.keras.layers.Input((28, 28, 1));
var x = tf.keras.layers.Flatten().Apply(inputs);
x = tf.keras.layers.Dense(100, activation: "relu").Apply(x);
x = tf.keras.layers.Dense(units: 10).Apply(x);
var outputs = tf.keras.layers.Softmax(axis: 1).Apply(x);
var model = tf.keras.Model(inputs, outputs);

model.compile(new Adam(0.001f),
tf.keras.losses.SparseCategoricalCrossentropy(),
new string[] { "accuracy" });

var data_loader = new MnistModelLoader();
var num_epochs = 1;
var batch_size = 50;

var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 58000,
}).Result;

model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);

model.save("./pb_simple_compile", save_format: "tf");
}

[TestMethod]
public void SimpleModelFromSequential()
{
Model model = keras.Sequential(new List<ILayer>()
{
tf.keras.layers.InputLayer((28, 28, 1)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(100, "relu"),
tf.keras.layers.Dense(10),
tf.keras.layers.Softmax()
});

model.summary();

model.compile(new Adam(0.001f), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

var data_loader = new MnistModelLoader();
var num_epochs = 1;
var batch_size = 50;

var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 58000,
}).Result;

model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);

model.save("./pb_simple_sequential", save_format: "tf");
}

[TestMethod]
public void AlexnetFromSequential()
{
Model model = keras.Sequential(new List<ILayer>()
{
tf.keras.layers.InputLayer((227, 227, 3)),
tf.keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.MaxPooling2D((3, 3), strides:(2, 2)),

tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: "relu"),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),

tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
tf.keras.layers.BatchNormalization(),

tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
tf.keras.layers.BatchNormalization(),

tf.keras.layers.Conv2D(256, (3, 3), (1, 1), "same", activation: "relu"),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),

tf.keras.layers.Flatten(),
tf.keras.layers.Dense(4096, activation: "relu"),
tf.keras.layers.Dropout(0.5f),

tf.keras.layers.Dense(4096, activation: "relu"),
tf.keras.layers.Dropout(0.5f),

tf.keras.layers.Dense(1000, activation: "linear"),
tf.keras.layers.Softmax(1)
});

model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

var num_epochs = 1;
var batch_size = 8;

var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);

model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);

model.save("./alexnet_from_sequential", save_format: "tf");

// The saved model can be tested with the following Python code:
#region alexnet_python_code
//import pathlib
//import tensorflow as tf

//def func(a):
// return -a

//if __name__ == '__main__':
// model = tf.keras.models.load_model("./pb_alex_sequential")
// model.summary()

// num_classes = 5
// batch_size = 128
// img_height = 227
// img_width = 227
// epochs = 100

// dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
// data_dir = tf.keras.utils.get_file('flower_photos', origin = dataset_url, untar = True)
// data_dir = pathlib.Path(data_dir)

// train_ds = tf.keras.preprocessing.image_dataset_from_directory(
// data_dir,
// validation_split = 0.2,
// subset = "training",
// seed = 123,
// image_size = (img_height, img_width),
// batch_size = batch_size)

// val_ds = tf.keras.preprocessing.image_dataset_from_directory(
// data_dir,
// validation_split = 0.2,
// subset = "validation",
// seed = 123,
// image_size = (img_height, img_width),
// batch_size = batch_size)


// model.compile(optimizer = 'adam',
// loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True),
// metrics =['accuracy'])

// model.build((None, img_height, img_width, 3))

// history = model.fit(
// train_ds,
// validation_data = val_ds,
// epochs = epochs
// )
#endregion
}
}
