Fix (#1036) and adjust the keras unittesttags/v0.100.5-BERT-load
@@ -905,13 +905,29 @@ namespace Tensorflow | |||||
var (a_reshape, a_free_dims, a_free_dims_static) = _tensordot_reshape(a, a_axes); | var (a_reshape, a_free_dims, a_free_dims_static) = _tensordot_reshape(a, a_axes); | ||||
var (b_reshape, b_free_dims, b_free_dims_static) = _tensordot_reshape(b, b_axes, true); | var (b_reshape, b_free_dims, b_free_dims_static) = _tensordot_reshape(b, b_axes, true); | ||||
var ab_matmul = matmul(a_reshape, b_reshape); | var ab_matmul = matmul(a_reshape, b_reshape); | ||||
var dims = new List<int>(); | |||||
dims.AddRange(a_free_dims); | |||||
dims.AddRange(b_free_dims); | |||||
if (ab_matmul.shape.Equals(dims)) | |||||
return ab_matmul; | |||||
if(a_free_dims is int[] a_free_dims_list && b_free_dims is int[] b_free_dims_list) | |||||
{ | |||||
var total_free_dims = a_free_dims_list.Concat(b_free_dims_list).ToArray(); | |||||
if (ab_matmul.shape.IsFullyDefined && ab_matmul.shape.as_int_list().SequenceEqual(total_free_dims)) | |||||
{ | |||||
return ab_matmul; | |||||
} | |||||
else | |||||
{ | |||||
return array_ops.reshape(ab_matmul, ops.convert_to_tensor(total_free_dims), name); | |||||
} | |||||
} | |||||
else | else | ||||
return array_ops.reshape(ab_matmul, tf.constant(dims.ToArray()), name: name); | |||||
{ | |||||
var a_free_dims_tensor = ops.convert_to_tensor(a_free_dims, dtype: dtypes.int32); | |||||
var b_free_dims_tensor = ops.convert_to_tensor(b_free_dims, dtype: dtypes.int32); | |||||
var product = array_ops.reshape(ab_matmul, array_ops.concat(new[] { a_free_dims_tensor, b_free_dims_tensor }, 0), name); | |||||
if(a_free_dims_static is not null && b_free_dims_static is not null) | |||||
{ | |||||
product.shape = new Shape(a_free_dims_static.Concat(b_free_dims_static).ToArray()); | |||||
} | |||||
return product; | |||||
} | |||||
}); | }); | ||||
} | } | ||||
@@ -927,14 +943,42 @@ namespace Tensorflow | |||||
return (Binding.range(a.shape.ndim - axe, a.shape.ndim).ToArray(), | return (Binding.range(a.shape.ndim - axe, a.shape.ndim).ToArray(), | ||||
Binding.range(0, axe).ToArray()); | Binding.range(0, axe).ToArray()); | ||||
} | } | ||||
else | |||||
else if(axes.rank == 1) | |||||
{ | { | ||||
if (axes.shape[0] != 2) | |||||
{ | |||||
throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}."); | |||||
} | |||||
(int a_axe, int b_axe) = (axes[0], axes[1]); | (int a_axe, int b_axe) = (axes[0], axes[1]); | ||||
return (new[] { a_axe }, new[] { b_axe }); | return (new[] { a_axe }, new[] { b_axe }); | ||||
} | } | ||||
else if(axes.rank == 2) | |||||
{ | |||||
if (axes.shape[0] != 2) | |||||
{ | |||||
throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}."); | |||||
} | |||||
int[] a_axes = new int[axes.shape[1]]; | |||||
int[] b_axes = new int[axes.shape[1]]; | |||||
for(int i = 0; i < a_axes.Length; i++) | |||||
{ | |||||
a_axes[i] = axes[0, i]; | |||||
b_axes[i] = axes[1, i]; | |||||
if (a_axes[i] == -1 || b_axes[i] == -1) | |||||
{ | |||||
throw new ValueError($"Different number of contraction axes `a` and `b`," + | |||||
$"{len(a_axes)} != {len(b_axes)}."); | |||||
} | |||||
} | |||||
return (a_axes, b_axes); | |||||
} | |||||
else | |||||
{ | |||||
throw new ValueError($"Invalid rank {axes.rank} to make tensor dot."); | |||||
} | |||||
} | } | ||||
static (Tensor, int[], int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false) | |||||
static (Tensor, object, int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false) | |||||
{ | { | ||||
if (a.shape.IsFullyDefined && isinstance(axes, (typeof(int[]), typeof(Tuple)))) | if (a.shape.IsFullyDefined && isinstance(axes, (typeof(int[]), typeof(Tuple)))) | ||||
{ | { | ||||
@@ -977,6 +1021,58 @@ namespace Tensorflow | |||||
var reshaped_a = array_ops.reshape(a_trans, new_shape); | var reshaped_a = array_ops.reshape(a_trans, new_shape); | ||||
return (reshaped_a, free_dims, free_dims); | return (reshaped_a, free_dims, free_dims); | ||||
} | } | ||||
else | |||||
{ | |||||
int[] free_dims_static; | |||||
Tensor converted_shape_a, converted_axes, converted_free; | |||||
if (a.shape.ndim != -1) | |||||
{ | |||||
var shape_a = a.shape.as_int_list(); | |||||
for(int i = 0; i < axes.Length; i++) | |||||
{ | |||||
if (axes[i] < 0) | |||||
{ | |||||
axes[i] += shape_a.Length; | |||||
} | |||||
} | |||||
var free = Enumerable.Range(0, shape_a.Length).Where(i => !axes.Contains(i)).ToArray(); | |||||
var axes_dims = axes.Select(i => shape_a[i]); | |||||
var free_dims = free.Select(i => shape_a[i]).ToArray(); | |||||
free_dims_static = free_dims; | |||||
converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes"); | |||||
converted_free = ops.convert_to_tensor(free, dtypes.int32, "free"); | |||||
converted_shape_a = array_ops.shape(a); | |||||
} | |||||
else | |||||
{ | |||||
free_dims_static = null; | |||||
converted_shape_a = array_ops.shape(a); | |||||
var rank_a = array_ops.rank(a); | |||||
converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes"); | |||||
converted_axes = array_ops.where_v2(converted_axes >= 0, converted_axes, converted_axes + rank_a); | |||||
(converted_free, var _) = gen_ops.list_diff(gen_math_ops.range(ops.convert_to_tensor(0), rank_a, ops.convert_to_tensor(1)), | |||||
converted_axes, dtypes.int32); | |||||
} | |||||
var converted_free_dims = array_ops.gather(converted_shape_a, converted_free); | |||||
var converted_axes_dims = array_ops.gather(converted_shape_a, converted_axes); | |||||
var prod_free_dims = reduce_prod(converted_free_dims); | |||||
var prod_axes_dims = reduce_prod(converted_axes_dims); | |||||
Tensor reshaped_a; | |||||
if (flipped) | |||||
{ | |||||
var perm = array_ops.concat(new[] { converted_axes, converted_free }, 0); | |||||
var new_shape = array_ops.stack(new[] { prod_axes_dims, prod_free_dims }); | |||||
reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape); | |||||
} | |||||
else | |||||
{ | |||||
var perm = array_ops.concat(new[] { converted_free, converted_axes }, 0); | |||||
var new_shape = array_ops.stack(new[] { prod_free_dims, prod_axes_dims }); | |||||
reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape); | |||||
} | |||||
return (reshaped_a, converted_free_dims, free_dims_static); | |||||
} | |||||
throw new NotImplementedException("_tensordot_reshape"); | throw new NotImplementedException("_tensordot_reshape"); | ||||
} | } | ||||
@@ -1,16 +1,11 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow.Keras.UnitTest.Helpers; | |||||
using static Tensorflow.Binding; | |||||
using Tensorflow; | |||||
using Tensorflow.Keras.Optimizers; | |||||
using System.Collections.Generic; | |||||
using Tensorflow.Keras.Callbacks; | using Tensorflow.Keras.Callbacks; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using System.Collections.Generic; | |||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
using Tensorflow.Keras; | |||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest.Callbacks | |||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class EarlystoppingTest | public class EarlystoppingTest | ||||
@@ -31,7 +26,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
layers.Dense(10) | layers.Dense(10) | ||||
}); | }); | ||||
model.summary(); | model.summary(); | ||||
model.compile(optimizer: keras.optimizers.RMSprop(1e-3f), | model.compile(optimizer: keras.optimizers.RMSprop(1e-3f), | ||||
@@ -55,7 +50,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
var callbacks = new List<ICallback>(); | var callbacks = new List<ICallback>(); | ||||
callbacks.add(earlystop); | callbacks.add(earlystop); | ||||
model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs,callbacks:callbacks); | |||||
model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs, callbacks: callbacks); | |||||
} | } | ||||
} | } | ||||
@@ -1,10 +1,8 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | using System; | ||||
using Tensorflow; | |||||
using Tensorflow.Keras; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest | |||||
{ | { | ||||
public class EagerModeTestBase | public class EagerModeTestBase | ||||
{ | { | ||||
@@ -1,14 +1,11 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.NumPy; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
using Tensorflow.NumPy; | |||||
using System; | |||||
using Tensorflow.Keras.Optimizers; | |||||
namespace TensorFlowNET.Keras.UnitTest; | |||||
namespace Tensorflow.Keras.UnitTest; | |||||
[TestClass] | [TestClass] | ||||
public class GradientTest : EagerModeTestBase | public class GradientTest : EagerModeTestBase | ||||
@@ -1,12 +1,7 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using TensorFlowNET.Keras.UnitTest; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.Keras.UnitTest; | |||||
namespace Tensorflow.Keras.UnitTest; | |||||
[TestClass] | [TestClass] | ||||
public class InitializerTest : EagerModeTestBase | public class InitializerTest : EagerModeTestBase | ||||
@@ -1,12 +1,10 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | |||||
using System.Collections.Generic; | |||||
using static Tensorflow.Binding; | |||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using static Tensorflow.Binding; | |||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
using Tensorflow; | |||||
namespace TensorFlowNET.Keras.UnitTest { | |||||
namespace Tensorflow.Keras.UnitTest.Layers | |||||
{ | |||||
[TestClass] | [TestClass] | ||||
public class ActivationTest : EagerModeTestBase | public class ActivationTest : EagerModeTestBase | ||||
{ | { | ||||
@@ -1,15 +1,11 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | |||||
using System.Collections.Generic; | |||||
using Tensorflow.Keras.Layers; | |||||
using Tensorflow.Keras.Utils; | |||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
using Tensorflow.Keras.Layers; | |||||
using Tensorflow; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Utils; | |||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest.Layers | |||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class AttentionTest : EagerModeTestBase | public class AttentionTest : EagerModeTestBase | ||||
@@ -118,7 +114,8 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
} }, dtype: np.float32); | } }, dtype: np.float32); | ||||
var attention_layer = (Attention)keras.layers.Attention(score_mode: "concat"); | var attention_layer = (Attention)keras.layers.Attention(score_mode: "concat"); | ||||
//attention_layer.concat_score_weight = 1; | //attention_layer.concat_score_weight = 1; | ||||
attention_layer.concat_score_weight = base_layer_utils.make_variable(new VariableArgs() { | |||||
attention_layer.concat_score_weight = base_layer_utils.make_variable(new VariableArgs() | |||||
{ | |||||
Name = "concat_score_weight", | Name = "concat_score_weight", | ||||
Shape = (1), | Shape = (1), | ||||
DType = TF_DataType.TF_FLOAT, | DType = TF_DataType.TF_FLOAT, | ||||
@@ -156,7 +153,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
var query = keras.Input(shape: (4, 8)); | var query = keras.Input(shape: (4, 8)); | ||||
var value = keras.Input(shape: (2, 8)); | var value = keras.Input(shape: (2, 8)); | ||||
var mask_tensor = keras.Input(shape:(4, 2)); | |||||
var mask_tensor = keras.Input(shape: (4, 2)); | |||||
var attention_layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2); | var attention_layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2); | ||||
attention_layer.Apply(new Tensor[] { query, value, mask_tensor }); | attention_layer.Apply(new Tensor[] { query, value, mask_tensor }); | ||||
@@ -1,11 +1,9 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow.NumPy; | |||||
using Tensorflow; | |||||
using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
using static Tensorflow.Binding; | |||||
using Tensorflow.NumPy; | |||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest.Layers | |||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class CosineSimilarity | public class CosineSimilarity | ||||
@@ -16,7 +14,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
NDArray y_pred_float = new float[,] { { 1.0f, 0.0f }, { 1.0f, 1.0f } }; | NDArray y_pred_float = new float[,] { { 1.0f, 0.0f }, { 1.0f, 1.0f } }; | ||||
[TestMethod] | [TestMethod] | ||||
public void _Default() | public void _Default() | ||||
{ | { | ||||
//>>> # Using 'auto'/'sum_over_batch_size' reduction type. | //>>> # Using 'auto'/'sum_over_batch_size' reduction type. | ||||
@@ -27,7 +25,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
//>>> # loss = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1)) | //>>> # loss = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1)) | ||||
//>>> # = -((0. + 0.) + (0.5 + 0.5)) / 2 | //>>> # = -((0. + 0.) + (0.5 + 0.5)) / 2 | ||||
//-0.5 | //-0.5 | ||||
var loss = keras.losses.CosineSimilarity(axis : 1); | |||||
var loss = keras.losses.CosineSimilarity(axis: 1); | |||||
var call = loss.Call(y_true_float, y_pred_float); | var call = loss.Call(y_true_float, y_pred_float); | ||||
Assert.AreEqual((NDArray)(-0.49999997f), call.numpy()); | Assert.AreEqual((NDArray)(-0.49999997f), call.numpy()); | ||||
} | } | ||||
@@ -41,7 +39,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
//- 0.0999 | //- 0.0999 | ||||
var loss = keras.losses.CosineSimilarity(); | var loss = keras.losses.CosineSimilarity(); | ||||
var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.8f, 0.2f }); | var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.8f, 0.2f }); | ||||
Assert.AreEqual((NDArray) (- 0.099999994f), call.numpy()); | |||||
Assert.AreEqual((NDArray)(-0.099999994f), call.numpy()); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -53,7 +51,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
//... reduction = tf.keras.losses.Reduction.SUM) | //... reduction = tf.keras.losses.Reduction.SUM) | ||||
//>>> cosine_loss(y_true, y_pred).numpy() | //>>> cosine_loss(y_true, y_pred).numpy() | ||||
//- 0.999 | //- 0.999 | ||||
var loss = keras.losses.CosineSimilarity(axis: 1,reduction : ReductionV2.SUM); | |||||
var loss = keras.losses.CosineSimilarity(axis: 1, reduction: ReductionV2.SUM); | |||||
var call = loss.Call(y_true_float, y_pred_float); | var call = loss.Call(y_true_float, y_pred_float); | ||||
Assert.AreEqual((NDArray)(-0.99999994f), call.numpy()); | Assert.AreEqual((NDArray)(-0.99999994f), call.numpy()); | ||||
} | } | ||||
@@ -67,7 +65,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
//... reduction = tf.keras.losses.Reduction.NONE) | //... reduction = tf.keras.losses.Reduction.NONE) | ||||
//>>> cosine_loss(y_true, y_pred).numpy() | //>>> cosine_loss(y_true, y_pred).numpy() | ||||
//array([-0., -0.999], dtype = float32) | //array([-0., -0.999], dtype = float32) | ||||
var loss = keras.losses.CosineSimilarity(axis :1, reduction: ReductionV2.NONE); | |||||
var loss = keras.losses.CosineSimilarity(axis: 1, reduction: ReductionV2.NONE); | |||||
var call = loss.Call(y_true_float, y_pred_float); | var call = loss.Call(y_true_float, y_pred_float); | ||||
Assert.AreEqual((NDArray)new float[] { -0f, -0.99999994f }, call.numpy()); | Assert.AreEqual((NDArray)new float[] { -0f, -0.99999994f }, call.numpy()); | ||||
} | } | ||||
@@ -1,11 +1,9 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow.NumPy; | |||||
using Tensorflow; | |||||
using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
using static Tensorflow.Binding; | |||||
using Tensorflow.NumPy; | |||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest.Layers | |||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class Huber | public class Huber | ||||
@@ -16,7 +14,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
NDArray y_pred_float = new float[,] { { 0.6f, 0.4f }, { 0.4f, 0.6f } }; | NDArray y_pred_float = new float[,] { { 0.6f, 0.4f }, { 0.4f, 0.6f } }; | ||||
[TestMethod] | [TestMethod] | ||||
public void _Default() | public void _Default() | ||||
{ | { | ||||
//>>> # Using 'auto'/'sum_over_batch_size' reduction type. | //>>> # Using 'auto'/'sum_over_batch_size' reduction type. | ||||
@@ -49,7 +47,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
//... reduction = tf.keras.losses.Reduction.SUM) | //... reduction = tf.keras.losses.Reduction.SUM) | ||||
//>>> h(y_true, y_pred).numpy() | //>>> h(y_true, y_pred).numpy() | ||||
//0.31 | //0.31 | ||||
var loss = keras.losses.Huber(reduction : ReductionV2.SUM); | |||||
var loss = keras.losses.Huber(reduction: ReductionV2.SUM); | |||||
var call = loss.Call(y_true_float, y_pred_float); | var call = loss.Call(y_true_float, y_pred_float); | ||||
Assert.AreEqual((NDArray)0.31f, call.numpy()); | Assert.AreEqual((NDArray)0.31f, call.numpy()); | ||||
} | } | ||||
@@ -1,10 +1,8 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using Tensorflow; | |||||
using Tensorflow.Operations; | |||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest.Layers | |||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class LayersConvolutionTest : EagerModeTestBase | public class LayersConvolutionTest : EagerModeTestBase | ||||
@@ -14,7 +12,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
{ | { | ||||
var filters = 8; | var filters = 8; | ||||
var conv = keras.layers.Conv1D(filters, kernel_size: 3, activation: "linear"); | |||||
var conv = keras.layers.Conv1D(filters, kernel_size: 3, activation: "linear"); | |||||
var x = np.arange(256.0f).reshape((8, 8, 4)); | var x = np.arange(256.0f).reshape((8, 8, 4)); | ||||
var y = conv.Apply(x); | var y = conv.Apply(x); | ||||
@@ -1,39 +1,43 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow; | |||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
namespace TensorFlowNET.Keras.UnitTest { | |||||
[TestClass] | |||||
public class LayersCroppingTest : EagerModeTestBase { | |||||
[TestMethod] | |||||
public void Cropping1D () { | |||||
Shape input_shape = (1, 5, 2); | |||||
var x = tf.zeros(input_shape); | |||||
var cropping_1d = keras.layers.Cropping1D(new[] { 1, 2 }); | |||||
var y = cropping_1d.Apply(x); | |||||
Assert.AreEqual((1, 2, 2), y.shape); | |||||
} | |||||
namespace Tensorflow.Keras.UnitTest.Layers | |||||
{ | |||||
[TestClass] | |||||
public class LayersCroppingTest : EagerModeTestBase | |||||
{ | |||||
[TestMethod] | |||||
public void Cropping1D() | |||||
{ | |||||
Shape input_shape = (1, 5, 2); | |||||
var x = tf.zeros(input_shape); | |||||
var cropping_1d = keras.layers.Cropping1D(new[] { 1, 2 }); | |||||
var y = cropping_1d.Apply(x); | |||||
Assert.AreEqual((1, 2, 2), y.shape); | |||||
} | |||||
[TestMethod] | |||||
public void Cropping2D () { | |||||
Shape input_shape = (1, 5, 6, 1); | |||||
NDArray cropping = new NDArray(new[,] { { 1, 2 }, { 1, 3 } }); | |||||
var x = tf.zeros(input_shape); | |||||
var cropping_2d = keras.layers.Cropping2D(cropping); | |||||
var y = cropping_2d.Apply(x); | |||||
Assert.AreEqual((1, 2, 2, 1), y.shape); | |||||
} | |||||
[TestMethod] | |||||
public void Cropping2D() | |||||
{ | |||||
Shape input_shape = (1, 5, 6, 1); | |||||
NDArray cropping = new NDArray(new[,] { { 1, 2 }, { 1, 3 } }); | |||||
var x = tf.zeros(input_shape); | |||||
var cropping_2d = keras.layers.Cropping2D(cropping); | |||||
var y = cropping_2d.Apply(x); | |||||
Assert.AreEqual((1, 2, 2, 1), y.shape); | |||||
} | |||||
[TestMethod] | |||||
public void Cropping3D () { | |||||
Shape input_shape = new Shape(1, 5, 6, 7, 1); | |||||
NDArray cropping = new NDArray(new[,] { { 1, 2 }, { 1, 3 }, { 1, 4 } }); | |||||
var x = tf.zeros(input_shape); | |||||
var cropping_3d = keras.layers.Cropping3D(cropping); | |||||
var y = cropping_3d.Apply(x); | |||||
Assert.AreEqual(new Shape(1, 2, 2, 2, 1), y.shape); | |||||
} | |||||
} | |||||
[TestMethod] | |||||
public void Cropping3D() | |||||
{ | |||||
Shape input_shape = new Shape(1, 5, 6, 7, 1); | |||||
NDArray cropping = new NDArray(new[,] { { 1, 2 }, { 1, 3 }, { 1, 4 } }); | |||||
var x = tf.zeros(input_shape); | |||||
var cropping_3d = keras.layers.Cropping3D(cropping); | |||||
var y = cropping_3d.Apply(x); | |||||
Assert.AreEqual(new Shape(1, 2, 2, 2, 1), y.shape); | |||||
} | |||||
} | |||||
} | } |
@@ -1,9 +1,8 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using Tensorflow; | |||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest.Layers | |||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class LayersMergingTest : EagerModeTestBase | public class LayersMergingTest : EagerModeTestBase | ||||
@@ -1,43 +1,48 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow; | |||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
namespace TensorFlowNET.Keras.UnitTest { | |||||
[TestClass] | |||||
public class LayersReshapingTest : EagerModeTestBase { | |||||
[TestMethod] | |||||
public void ZeroPadding2D () { | |||||
Shape input_shape = (1, 1, 2, 2); | |||||
var x = np.arange(input_shape.size).reshape(input_shape); | |||||
var zero_padding_2d = keras.layers.ZeroPadding2D(new[,] { { 1, 0 }, { 1, 0 } }); | |||||
var y = zero_padding_2d.Apply(x); | |||||
Assert.AreEqual((1, 2, 3, 2), y.shape); | |||||
} | |||||
namespace Tensorflow.Keras.UnitTest.Layers | |||||
{ | |||||
[TestClass] | |||||
public class LayersReshapingTest : EagerModeTestBase | |||||
{ | |||||
[TestMethod] | |||||
public void ZeroPadding2D() | |||||
{ | |||||
Shape input_shape = (1, 1, 2, 2); | |||||
var x = np.arange(input_shape.size).reshape(input_shape); | |||||
var zero_padding_2d = keras.layers.ZeroPadding2D(new[,] { { 1, 0 }, { 1, 0 } }); | |||||
var y = zero_padding_2d.Apply(x); | |||||
Assert.AreEqual((1, 2, 3, 2), y.shape); | |||||
} | |||||
[TestMethod] | |||||
public void UpSampling2D () { | |||||
Shape input_shape = (2, 2, 1, 3); | |||||
var x = np.arange(input_shape.size).reshape(input_shape); | |||||
var y = keras.layers.UpSampling2D(size: (1, 2)).Apply(x); | |||||
Assert.AreEqual((2, 2, 2, 3), y.shape); | |||||
} | |||||
[TestMethod] | |||||
public void UpSampling2D() | |||||
{ | |||||
Shape input_shape = (2, 2, 1, 3); | |||||
var x = np.arange(input_shape.size).reshape(input_shape); | |||||
var y = keras.layers.UpSampling2D(size: (1, 2)).Apply(x); | |||||
Assert.AreEqual((2, 2, 2, 3), y.shape); | |||||
} | |||||
[TestMethod] | |||||
public void Reshape () { | |||||
var inputs = tf.zeros((10, 5, 20)); | |||||
var outputs = keras.layers.LeakyReLU().Apply(inputs); | |||||
outputs = keras.layers.Reshape((20, 5)).Apply(outputs); | |||||
Assert.AreEqual((10, 20, 5), outputs.shape); | |||||
} | |||||
[TestMethod] | |||||
public void Reshape() | |||||
{ | |||||
var inputs = tf.zeros((10, 5, 20)); | |||||
var outputs = keras.layers.LeakyReLU().Apply(inputs); | |||||
outputs = keras.layers.Reshape((20, 5)).Apply(outputs); | |||||
Assert.AreEqual((10, 20, 5), outputs.shape); | |||||
} | |||||
[TestMethod] | |||||
public void Permute () { | |||||
var inputs = tf.zeros((2, 3, 4, 5)); | |||||
var outputs = keras.layers.Permute(new int[] { 3, 2, 1 }).Apply(inputs); | |||||
Assert.AreEqual((2, 5, 4, 3), outputs.shape); | |||||
} | |||||
[TestMethod] | |||||
public void Permute() | |||||
{ | |||||
var inputs = tf.zeros((2, 3, 4, 5)); | |||||
var outputs = keras.layers.Permute(new int[] { 3, 2, 1 }).Apply(inputs); | |||||
Assert.AreEqual((2, 5, 4, 3), outputs.shape); | |||||
} | |||||
} | |||||
} | |||||
} | } |
@@ -1,13 +1,10 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow.NumPy; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow; | |||||
using Tensorflow.Keras; | |||||
using Tensorflow.NumPy; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
using System.Linq; | |||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest.Layers | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers | /// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers | ||||
@@ -235,7 +232,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
// one-hot | // one-hot | ||||
var inputs = np.array(new[] { 3, 2, 0, 1 }); | var inputs = np.array(new[] { 3, 2, 0, 1 }); | ||||
var layer = tf.keras.layers.CategoryEncoding(4); | var layer = tf.keras.layers.CategoryEncoding(4); | ||||
Tensor output = layer.Apply(inputs); | Tensor output = layer.Apply(inputs); | ||||
Assert.AreEqual((4, 4), output.shape); | Assert.AreEqual((4, 4), output.shape); | ||||
Assert.IsTrue(output[0].numpy().Equals(new[] { 0, 0, 0, 1f })); | Assert.IsTrue(output[0].numpy().Equals(new[] { 0, 0, 0, 1f })); | ||||
@@ -1,11 +1,9 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow.NumPy; | |||||
using Tensorflow; | |||||
using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
using static Tensorflow.Binding; | |||||
using Tensorflow.NumPy; | |||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest.Layers | |||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class LogCosh | public class LogCosh | ||||
@@ -16,7 +14,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 0.0f, 0.0f } }; | NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 0.0f, 0.0f } }; | ||||
[TestMethod] | [TestMethod] | ||||
public void _Default() | public void _Default() | ||||
{ | { | ||||
//>>> # Using 'auto'/'sum_over_batch_size' reduction type. | //>>> # Using 'auto'/'sum_over_batch_size' reduction type. | ||||
@@ -32,9 +30,9 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
public void _Sample_Weight() | public void _Sample_Weight() | ||||
{ | { | ||||
//>>> # Calling with 'sample_weight'. | |||||
//>>> l(y_true, y_pred, sample_weight =[0.8, 0.2]).numpy() | |||||
//0.087 | |||||
//>>> # Calling with 'sample_weight'. | |||||
//>>> l(y_true, y_pred, sample_weight =[0.8, 0.2]).numpy() | |||||
//0.087 | |||||
var loss = keras.losses.LogCosh(); | var loss = keras.losses.LogCosh(); | ||||
var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.8f, 0.2f }); | var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.8f, 0.2f }); | ||||
Assert.AreEqual((NDArray)0.08675616f, call.numpy()); | Assert.AreEqual((NDArray)0.08675616f, call.numpy()); | ||||
@@ -49,7 +47,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
//... reduction = tf.keras.losses.Reduction.SUM) | //... reduction = tf.keras.losses.Reduction.SUM) | ||||
//>>> l(y_true, y_pred).numpy() | //>>> l(y_true, y_pred).numpy() | ||||
//0.217 | //0.217 | ||||
var loss = keras.losses.LogCosh(reduction : ReductionV2.SUM); | |||||
var loss = keras.losses.LogCosh(reduction: ReductionV2.SUM); | |||||
var call = loss.Call(y_true_float, y_pred_float); | var call = loss.Call(y_true_float, y_pred_float); | ||||
Assert.AreEqual((NDArray)0.2168904f, call.numpy()); | Assert.AreEqual((NDArray)0.2168904f, call.numpy()); | ||||
} | } | ||||
@@ -1,11 +1,9 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow.NumPy; | |||||
using Tensorflow; | |||||
using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
using static Tensorflow.Binding; | |||||
using Tensorflow.NumPy; | |||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest.Layers | |||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class MeanAbsoluteError | public class MeanAbsoluteError | ||||
@@ -50,7 +48,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
//... reduction = tf.keras.losses.Reduction.SUM) | //... reduction = tf.keras.losses.Reduction.SUM) | ||||
//>>> mae(y_true, y_pred).numpy() | //>>> mae(y_true, y_pred).numpy() | ||||
//1.0 | //1.0 | ||||
var loss = keras.losses.MeanAbsoluteError( reduction: ReductionV2.SUM); | |||||
var loss = keras.losses.MeanAbsoluteError(reduction: ReductionV2.SUM); | |||||
var call = loss.Call(y_true_float, y_pred_float); | var call = loss.Call(y_true_float, y_pred_float); | ||||
Assert.AreEqual((NDArray)(1.0f), call.numpy()); | Assert.AreEqual((NDArray)(1.0f), call.numpy()); | ||||
} | } | ||||
@@ -1,11 +1,9 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow.NumPy; | |||||
using Tensorflow; | |||||
using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
using static Tensorflow.Binding; | |||||
using Tensorflow.NumPy; | |||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest.Layers | |||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class MeanAbsolutePercentageError | public class MeanAbsolutePercentageError | ||||
@@ -49,7 +47,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
//... reduction = tf.keras.losses.Reduction.SUM) | //... reduction = tf.keras.losses.Reduction.SUM) | ||||
//>>> mape(y_true, y_pred).numpy() | //>>> mape(y_true, y_pred).numpy() | ||||
//100. | //100. | ||||
var loss = keras.losses.MeanAbsolutePercentageError( reduction: ReductionV2.SUM); | |||||
var loss = keras.losses.MeanAbsolutePercentageError(reduction: ReductionV2.SUM); | |||||
var call = loss.Call(y_true_float, y_pred_float); | var call = loss.Call(y_true_float, y_pred_float); | ||||
Assert.AreEqual((NDArray)(100f), call.numpy()); | Assert.AreEqual((NDArray)(100f), call.numpy()); | ||||
} | } | ||||
@@ -1,14 +1,11 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using Tensorflow; | |||||
using Tensorflow.Keras.Losses; | |||||
using static Tensorflow.Binding; | |||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest.Layers | |||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class MeanSquaredErrorTest | |||||
public class MeanSquaredErrorTest | |||||
{ | { | ||||
//https://keras.io/api/losses/regression_losses/#meansquarederror-class | //https://keras.io/api/losses/regression_losses/#meansquarederror-class | ||||
@@ -16,7 +13,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
private NDArray y_pred = new double[,] { { 1.0, 1.0 }, { 1.0, 0.0 } }; | private NDArray y_pred = new double[,] { { 1.0, 1.0 }, { 1.0, 0.0 } }; | ||||
[TestMethod] | [TestMethod] | ||||
public void Mse_Double() | public void Mse_Double() | ||||
{ | { | ||||
var mse = keras.losses.MeanSquaredError(); | var mse = keras.losses.MeanSquaredError(); | ||||
@@ -25,7 +22,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
public void Mse_Float() | public void Mse_Float() | ||||
{ | { | ||||
NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } }; | NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } }; | ||||
@@ -1,11 +1,9 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow.NumPy; | |||||
using Tensorflow; | |||||
using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
using static Tensorflow.Binding; | |||||
using Tensorflow.NumPy; | |||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest.Layers | |||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class MeanSquaredLogarithmicError | public class MeanSquaredLogarithmicError | ||||
@@ -49,7 +47,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
//... reduction = tf.keras.losses.Reduction.SUM) | //... reduction = tf.keras.losses.Reduction.SUM) | ||||
//>>> msle(y_true, y_pred).numpy() | //>>> msle(y_true, y_pred).numpy() | ||||
//0.480 | //0.480 | ||||
var loss = keras.losses.MeanSquaredLogarithmicError( reduction: ReductionV2.SUM); | |||||
var loss = keras.losses.MeanSquaredLogarithmicError(reduction: ReductionV2.SUM); | |||||
var call = loss.Call(y_true_float, y_pred_float); | var call = loss.Call(y_true_float, y_pred_float); | ||||
Assert.AreEqual((NDArray)(0.48045287f), call.numpy()); | Assert.AreEqual((NDArray)(0.48045287f), call.numpy()); | ||||
} | } | ||||
@@ -1,35 +0,0 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using Tensorflow.Keras.Engine; | |||||
using System.Diagnostics; | |||||
using static Tensorflow.KerasApi; | |||||
using Tensorflow.Keras.Saving; | |||||
using Tensorflow.Keras.Models; | |||||
namespace TensorFlowNET.Keras.UnitTest | |||||
{ | |||||
/// <summary> | |||||
/// https://www.tensorflow.org/guide/keras/save_and_serialize | |||||
/// </summary> | |||||
[TestClass] | |||||
public class ModelSaveTest : EagerModeTestBase | |||||
{ | |||||
[TestMethod] | |||||
public void GetAndFromConfig() | |||||
{ | |||||
var model = GetFunctionalModel(); | |||||
var config = model.get_config(); | |||||
Debug.Assert(config is FunctionalConfig); | |||||
var new_model = new ModelsApi().from_config(config as FunctionalConfig); | |||||
Assert.AreEqual(model.Layers.Count, new_model.Layers.Count); | |||||
} | |||||
IModel GetFunctionalModel() | |||||
{ | |||||
// Create a simple model. | |||||
var inputs = keras.Input(shape: 32); | |||||
var dense_layer = keras.layers.Dense(1); | |||||
var outputs = dense_layer.Apply(inputs); | |||||
return keras.Model(inputs, outputs); | |||||
} | |||||
} | |||||
} |
@@ -1,12 +1,8 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using System.Linq; | |||||
using Tensorflow; | |||||
using static Tensorflow.Binding; | |||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
using Microsoft.VisualBasic; | |||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest.Layers | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers | /// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers | ||||
@@ -231,7 +227,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
public void Max1DPoolingChannelsLast() | public void Max1DPoolingChannelsLast() | ||||
{ | { | ||||
var x = input_array_1D; | var x = input_array_1D; | ||||
var pool = keras.layers.MaxPooling1D(pool_size:2, strides:1); | |||||
var pool = keras.layers.MaxPooling1D(pool_size: 2, strides: 1); | |||||
var y = pool.Apply(x); | var y = pool.Apply(x); | ||||
Assert.AreEqual(4, y.shape[0]); | Assert.AreEqual(4, y.shape[0]); | ||||
@@ -1,16 +1,8 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using System.Threading.Tasks; | |||||
using Tensorflow; | |||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using TensorFlowNET.Keras.UnitTest; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | |||||
namespace TensorFlowNET.Keras.UnitTest; | |||||
namespace Tensorflow.Keras.UnitTest.Losses; | |||||
[TestClass] | [TestClass] | ||||
public class LossesTest : EagerModeTestBase | public class LossesTest : EagerModeTestBase | ||||
@@ -47,7 +39,7 @@ public class LossesTest : EagerModeTestBase | |||||
// Using 'none' reduction type. | // Using 'none' reduction type. | ||||
bce = tf.keras.losses.BinaryCrossentropy(from_logits: true, reduction: Reduction.NONE); | bce = tf.keras.losses.BinaryCrossentropy(from_logits: true, reduction: Reduction.NONE); | ||||
loss = bce.Call(y_true, y_pred); | loss = bce.Call(y_true, y_pred); | ||||
Assert.AreEqual(new float[] { 0.23515666f, 1.4957594f}, loss.numpy()); | |||||
Assert.AreEqual(new float[] { 0.23515666f, 1.4957594f }, loss.numpy()); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -1,15 +1,8 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using System.Threading.Tasks; | |||||
using Tensorflow; | |||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | |||||
namespace TensorFlowNET.Keras.UnitTest; | |||||
namespace Tensorflow.Keras.UnitTest.Layers.Metrics; | |||||
[TestClass] | [TestClass] | ||||
public class MetricsTest : EagerModeTestBase | public class MetricsTest : EagerModeTestBase | ||||
@@ -40,7 +33,7 @@ public class MetricsTest : EagerModeTestBase | |||||
[TestMethod] | [TestMethod] | ||||
public void BinaryAccuracy() | public void BinaryAccuracy() | ||||
{ | { | ||||
var y_true = np.array(new[,] { { 1 }, { 1 },{ 0 }, { 0 } }); | |||||
var y_true = np.array(new[,] { { 1 }, { 1 }, { 0 }, { 0 } }); | |||||
var y_pred = np.array(new[,] { { 0.98f }, { 1f }, { 0f }, { 0.6f } }); | var y_pred = np.array(new[,] { { 0.98f }, { 1f }, { 0f }, { 0.6f } }); | ||||
var m = tf.keras.metrics.BinaryAccuracy(); | var m = tf.keras.metrics.BinaryAccuracy(); | ||||
m.update_state(y_true, y_pred); | m.update_state(y_true, y_pred); | ||||
@@ -183,17 +176,17 @@ public class MetricsTest : EagerModeTestBase | |||||
public void HammingLoss() | public void HammingLoss() | ||||
{ | { | ||||
// multi-class hamming loss | // multi-class hamming loss | ||||
var y_true = np.array(new[,] | |||||
{ | |||||
{ 1, 0, 0, 0 }, | |||||
{ 0, 0, 1, 0 }, | |||||
{ 0, 0, 0, 1 }, | |||||
{ 0, 1, 0, 0 } | |||||
var y_true = np.array(new[,] | |||||
{ | |||||
{ 1, 0, 0, 0 }, | |||||
{ 0, 0, 1, 0 }, | |||||
{ 0, 0, 0, 1 }, | |||||
{ 0, 1, 0, 0 } | |||||
}); | }); | ||||
var y_pred = np.array(new[,] | |||||
{ | |||||
{ 0.8f, 0.1f, 0.1f, 0.0f }, | |||||
{ 0.2f, 0.0f, 0.8f, 0.0f }, | |||||
var y_pred = np.array(new[,] | |||||
{ | |||||
{ 0.8f, 0.1f, 0.1f, 0.0f }, | |||||
{ 0.2f, 0.0f, 0.8f, 0.0f }, | |||||
{ 0.05f, 0.05f, 0.1f, 0.8f }, | { 0.05f, 0.05f, 0.1f, 0.8f }, | ||||
{ 1.0f, 0.0f, 0.0f, 0.0f } | { 1.0f, 0.0f, 0.0f, 0.0f } | ||||
}); | }); | ||||
@@ -0,0 +1,37 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Keras.UnitTest.Model | |||||
{ | |||||
[TestClass] | |||||
public class ModelBuildTest | |||||
{ | |||||
[TestMethod] | |||||
public void DenseBuild() | |||||
{ | |||||
// two dimensions input with unknown batchsize | |||||
var input = tf.keras.layers.Input((17, 60)); | |||||
var dense = tf.keras.layers.Dense(64); | |||||
var output = dense.Apply(input); | |||||
var model = tf.keras.Model(input, output); | |||||
// one dimensions input with unknown batchsize | |||||
var input_2 = tf.keras.layers.Input((60)); | |||||
var dense_2 = tf.keras.layers.Dense(64); | |||||
var output_2 = dense.Apply(input_2); | |||||
var model_2 = tf.keras.Model(input_2, output_2); | |||||
// two dimensions input with specified batchsize | |||||
var input_3 = tf.keras.layers.Input((17, 60), 8); | |||||
var dense_3 = tf.keras.layers.Dense(64); | |||||
var output_3 = dense.Apply(input_3); | |||||
var model_3 = tf.keras.Model(input_3, output_3); | |||||
// one dimensions input with specified batchsize | |||||
var input_4 = tf.keras.layers.Input((60), 8); | |||||
var dense_4 = tf.keras.layers.Dense(64); | |||||
var output_4 = dense.Apply(input_4); | |||||
var model_4 = tf.keras.Model(input_4, output_4); | |||||
} | |||||
} | |||||
} |
@@ -1,18 +1,15 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow; | |||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
using Tensorflow.Keras.UnitTest.Helpers; | using Tensorflow.Keras.UnitTest.Helpers; | ||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
namespace TensorFlowNET.Keras.UnitTest.SaveModel; | |||||
namespace Tensorflow.Keras.UnitTest.Model; | |||||
[TestClass] | [TestClass] | ||||
public class SequentialModelLoad | |||||
public class ModelLoadTest | |||||
{ | { | ||||
[TestMethod] | [TestMethod] | ||||
public void SimpleModelFromAutoCompile() | public void SimpleModelFromAutoCompile() | ||||
@@ -46,7 +43,7 @@ public class SequentialModelLoad | |||||
[TestMethod] | [TestMethod] | ||||
public void AlexnetFromSequential() | public void AlexnetFromSequential() | ||||
{ | { | ||||
new SequentialModelSave().AlexnetFromSequential(); | |||||
new ModelSaveTest().AlexnetFromSequential(); | |||||
var model = tf.keras.models.load_model(@"./alexnet_from_sequential"); | var model = tf.keras.models.load_model(@"./alexnet_from_sequential"); | ||||
model.summary(); | model.summary(); | ||||
@@ -89,7 +86,7 @@ public class SequentialModelLoad | |||||
var model = tf.keras.models.load_model(@"D:\development\tf.net\models\VGG19"); | var model = tf.keras.models.load_model(@"D:\development\tf.net\models\VGG19"); | ||||
model.summary(); | model.summary(); | ||||
var classify_model = keras.Sequential(new System.Collections.Generic.List<Tensorflow.Keras.ILayer>() | |||||
var classify_model = keras.Sequential(new System.Collections.Generic.List<ILayer>() | |||||
{ | { | ||||
model, | model, | ||||
keras.layers.Flatten(), | keras.layers.Flatten(), | ||||
@@ -100,7 +97,7 @@ public class SequentialModelLoad | |||||
classify_model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | classify_model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | ||||
var x = np.random.uniform(0, 1, (8, 512, 512, 3)); | var x = np.random.uniform(0, 1, (8, 512, 512, 3)); | ||||
var y = np.ones((8)); | |||||
var y = np.ones(8); | |||||
classify_model.fit(x, y, batch_size: 4); | classify_model.fit(x, y, batch_size: 4); | ||||
} | } | ||||
@@ -110,7 +107,7 @@ public class SequentialModelLoad | |||||
public void TestModelBeforeTF2_5() | public void TestModelBeforeTF2_5() | ||||
{ | { | ||||
var a = keras.layers; | var a = keras.layers; | ||||
var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Model; | |||||
var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model; | |||||
model.summary(); | model.summary(); | ||||
} | } | ||||
} | } |
@@ -0,0 +1,200 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Keras.Models; | |||||
using Tensorflow.Keras.Optimizers; | |||||
using Tensorflow.Keras.Saving; | |||||
using Tensorflow.Keras.UnitTest.Helpers; | |||||
using static Tensorflow.Binding; | |||||
using static Tensorflow.KerasApi; | |||||
namespace Tensorflow.Keras.UnitTest.Model | |||||
{ | |||||
/// <summary> | |||||
/// https://www.tensorflow.org/guide/keras/save_and_serialize | |||||
/// </summary> | |||||
[TestClass] | |||||
public class ModelSaveTest : EagerModeTestBase | |||||
{ | |||||
[TestMethod] | |||||
public void GetAndFromConfig() | |||||
{ | |||||
var model = GetFunctionalModel(); | |||||
var config = model.get_config(); | |||||
Debug.Assert(config is FunctionalConfig); | |||||
var new_model = new ModelsApi().from_config(config as FunctionalConfig); | |||||
Assert.AreEqual(model.Layers.Count, new_model.Layers.Count); | |||||
} | |||||
IModel GetFunctionalModel() | |||||
{ | |||||
// Create a simple model. | |||||
var inputs = keras.Input(shape: 32); | |||||
var dense_layer = keras.layers.Dense(1); | |||||
var outputs = dense_layer.Apply(inputs); | |||||
return keras.Model(inputs, outputs); | |||||
} | |||||
[TestMethod] | |||||
public void SimpleModelFromAutoCompile() | |||||
{ | |||||
var inputs = tf.keras.layers.Input((28, 28, 1)); | |||||
var x = tf.keras.layers.Flatten().Apply(inputs); | |||||
x = tf.keras.layers.Dense(100, activation: "relu").Apply(x); | |||||
x = tf.keras.layers.Dense(units: 10).Apply(x); | |||||
var outputs = tf.keras.layers.Softmax(axis: 1).Apply(x); | |||||
var model = tf.keras.Model(inputs, outputs); | |||||
model.compile(new Adam(0.001f), | |||||
tf.keras.losses.SparseCategoricalCrossentropy(), | |||||
new string[] { "accuracy" }); | |||||
var data_loader = new MnistModelLoader(); | |||||
var num_epochs = 1; | |||||
var batch_size = 50; | |||||
var dataset = data_loader.LoadAsync(new ModelLoadSetting | |||||
{ | |||||
TrainDir = "mnist", | |||||
OneHot = false, | |||||
ValidationSize = 58000, | |||||
}).Result; | |||||
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||||
model.save("./pb_simple_compile", save_format: "tf"); | |||||
} | |||||
[TestMethod] | |||||
public void SimpleModelFromSequential() | |||||
{ | |||||
var model = keras.Sequential(new List<ILayer>() | |||||
{ | |||||
tf.keras.layers.InputLayer((28, 28, 1)), | |||||
tf.keras.layers.Flatten(), | |||||
tf.keras.layers.Dense(100, "relu"), | |||||
tf.keras.layers.Dense(10), | |||||
tf.keras.layers.Softmax() | |||||
}); | |||||
model.summary(); | |||||
model.compile(new Adam(0.001f), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | |||||
var data_loader = new MnistModelLoader(); | |||||
var num_epochs = 1; | |||||
var batch_size = 50; | |||||
var dataset = data_loader.LoadAsync(new ModelLoadSetting | |||||
{ | |||||
TrainDir = "mnist", | |||||
OneHot = false, | |||||
ValidationSize = 58000, | |||||
}).Result; | |||||
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||||
model.save("./pb_simple_sequential", save_format: "tf"); | |||||
} | |||||
[TestMethod] | |||||
public void AlexnetFromSequential() | |||||
{ | |||||
var model = keras.Sequential(new List<ILayer>() | |||||
{ | |||||
tf.keras.layers.InputLayer((227, 227, 3)), | |||||
tf.keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"), | |||||
tf.keras.layers.BatchNormalization(), | |||||
tf.keras.layers.MaxPooling2D((3, 3), strides:(2, 2)), | |||||
tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: "relu"), | |||||
tf.keras.layers.BatchNormalization(), | |||||
tf.keras.layers.MaxPooling2D((3, 3), (2, 2)), | |||||
tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"), | |||||
tf.keras.layers.BatchNormalization(), | |||||
tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"), | |||||
tf.keras.layers.BatchNormalization(), | |||||
tf.keras.layers.Conv2D(256, (3, 3), (1, 1), "same", activation: "relu"), | |||||
tf.keras.layers.BatchNormalization(), | |||||
tf.keras.layers.MaxPooling2D((3, 3), (2, 2)), | |||||
tf.keras.layers.Flatten(), | |||||
tf.keras.layers.Dense(4096, activation: "relu"), | |||||
tf.keras.layers.Dropout(0.5f), | |||||
tf.keras.layers.Dense(4096, activation: "relu"), | |||||
tf.keras.layers.Dropout(0.5f), | |||||
tf.keras.layers.Dense(1000, activation: "linear"), | |||||
tf.keras.layers.Softmax(1) | |||||
}); | |||||
model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" }); | |||||
var num_epochs = 1; | |||||
var batch_size = 8; | |||||
var dataset = new RandomDataSet(new Shape(227, 227, 3), 16); | |||||
model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); | |||||
model.save("./alexnet_from_sequential", save_format: "tf"); | |||||
// The saved model can be test with the following python code: | |||||
#region alexnet_python_code | |||||
//import pathlib | |||||
//import tensorflow as tf | |||||
//def func(a): | |||||
// return -a | |||||
//if __name__ == '__main__': | |||||
// model = tf.keras.models.load_model("./pb_alex_sequential") | |||||
// model.summary() | |||||
// num_classes = 5 | |||||
// batch_size = 128 | |||||
// img_height = 227 | |||||
// img_width = 227 | |||||
// epochs = 100 | |||||
// dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz" | |||||
// data_dir = tf.keras.utils.get_file('flower_photos', origin = dataset_url, untar = True) | |||||
// data_dir = pathlib.Path(data_dir) | |||||
// train_ds = tf.keras.preprocessing.image_dataset_from_directory( | |||||
// data_dir, | |||||
// validation_split = 0.2, | |||||
// subset = "training", | |||||
// seed = 123, | |||||
// image_size = (img_height, img_width), | |||||
// batch_size = batch_size) | |||||
// val_ds = tf.keras.preprocessing.image_dataset_from_directory( | |||||
// data_dir, | |||||
// validation_split = 0.2, | |||||
// subset = "validation", | |||||
// seed = 123, | |||||
// image_size = (img_height, img_width), | |||||
// batch_size = batch_size) | |||||
// model.compile(optimizer = 'adam', | |||||
// loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True), | |||||
// metrics =['accuracy']) | |||||
// model.build((None, img_height, img_width, 3)) | |||||
// history = model.fit( | |||||
// train_ds, | |||||
// validation_data = val_ds, | |||||
// epochs = epochs | |||||
// ) | |||||
#endregion | |||||
} | |||||
} | |||||
} |
@@ -1,11 +1,10 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | using System; | ||||
using Tensorflow; | |||||
using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest | |||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class MultiInputModelTest | public class MultiInputModelTest | ||||
@@ -1,12 +1,12 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System; | using System; | ||||
using System.Threading.Tasks; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.NumPy; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
using System.Threading.Tasks; | |||||
using Tensorflow.NumPy; | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest | |||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class MultiThreads | public class MultiThreads | ||||
@@ -1,14 +1,8 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using System.Threading.Tasks; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
using Tensorflow.Keras; | |||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest | |||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class OutputTest | public class OutputTest | ||||
@@ -1,14 +1,8 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.NumPy; | |||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
using Tensorflow; | |||||
using Tensorflow.Keras.Datasets; | |||||
namespace TensorFlowNET.Keras.UnitTest | |||||
namespace Tensorflow.Keras.UnitTest | |||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class PreprocessingTests : EagerModeTestBase | public class PreprocessingTests : EagerModeTestBase | ||||
@@ -71,8 +65,8 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
Assert.AreEqual(28, tokenizer.word_index.Count); | Assert.AreEqual(28, tokenizer.word_index.Count); | ||||
Assert.AreEqual(1, tokenizer.word_index[OOV]); | |||||
Assert.AreEqual(8, tokenizer.word_index["worst"]); | |||||
Assert.AreEqual(1, tokenizer.word_index[OOV]); | |||||
Assert.AreEqual(8, tokenizer.word_index["worst"]); | |||||
Assert.AreEqual(13, tokenizer.word_index["number"]); | Assert.AreEqual(13, tokenizer.word_index["number"]); | ||||
Assert.AreEqual(17, tokenizer.word_index["were"]); | Assert.AreEqual(17, tokenizer.word_index["were"]); | ||||
} | } | ||||
@@ -204,13 +198,13 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
for (var i = 0; i < sequences.Count; i++) | for (var i = 0; i < sequences.Count; i++) | ||||
for (var j = 0; j < sequences[i].Length; j++) | for (var j = 0; j < sequences[i].Length; j++) | ||||
Assert.AreNotEqual(tokenizer.word_index[OOV], sequences[i][j]); | |||||
Assert.AreNotEqual(tokenizer.word_index[OOV], sequences[i][j]); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
public void TokenizeTextsToSequencesWithOOVPresent() | public void TokenizeTextsToSequencesWithOOVPresent() | ||||
{ | { | ||||
var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV, num_words:20); | |||||
var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV, num_words: 20); | |||||
tokenizer.fit_on_texts(texts); | tokenizer.fit_on_texts(texts); | ||||
var sequences = tokenizer.texts_to_sequences(texts); | var sequences = tokenizer.texts_to_sequences(texts); | ||||
@@ -255,7 +249,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
tokenizer.fit_on_texts(texts); | tokenizer.fit_on_texts(texts); | ||||
var sequences = tokenizer.texts_to_sequences(texts); | var sequences = tokenizer.texts_to_sequences(texts); | ||||
var padded = keras.preprocessing.sequence.pad_sequences(sequences,maxlen:15); | |||||
var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 15); | |||||
Assert.AreEqual(4, padded.dims[0]); | Assert.AreEqual(4, padded.dims[0]); | ||||
Assert.AreEqual(15, padded.dims[1]); | Assert.AreEqual(15, padded.dims[1]); | ||||
@@ -348,7 +342,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
Assert.AreEqual(27, tokenizer.word_index.Count); | Assert.AreEqual(27, tokenizer.word_index.Count); | ||||
var matrix = tokenizer.texts_to_matrix(texts, mode:"count"); | |||||
var matrix = tokenizer.texts_to_matrix(texts, mode: "count"); | |||||
Assert.AreEqual(texts.Length, matrix.dims[0]); | Assert.AreEqual(texts.Length, matrix.dims[0]); | ||||
@@ -1,176 +0,0 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System.Collections.Generic; | |||||
using Tensorflow; | |||||
using Tensorflow.Keras; | |||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Keras.Optimizers; | |||||
using Tensorflow.Keras.UnitTest.Helpers; | |||||
using static Tensorflow.Binding; | |||||
using static Tensorflow.KerasApi; | |||||
namespace TensorFlowNET.Keras.UnitTest.SaveModel; | |||||
[TestClass] | |||||
public class SequentialModelSave | |||||
{ | |||||
[TestMethod] | |||||
public void SimpleModelFromAutoCompile() | |||||
{ | |||||
var inputs = tf.keras.layers.Input((28, 28, 1)); | |||||
var x = tf.keras.layers.Flatten().Apply(inputs); | |||||
x = tf.keras.layers.Dense(100, activation: "relu").Apply(x); | |||||
x = tf.keras.layers.Dense(units: 10).Apply(x); | |||||
var outputs = tf.keras.layers.Softmax(axis: 1).Apply(x); | |||||
var model = tf.keras.Model(inputs, outputs); | |||||
model.compile(new Adam(0.001f), | |||||
tf.keras.losses.SparseCategoricalCrossentropy(), | |||||
new string[] { "accuracy" }); | |||||
var data_loader = new MnistModelLoader(); | |||||
var num_epochs = 1; | |||||
var batch_size = 50; | |||||
var dataset = data_loader.LoadAsync(new ModelLoadSetting | |||||
{ | |||||
TrainDir = "mnist", | |||||
OneHot = false, | |||||
ValidationSize = 58000, | |||||
}).Result; | |||||
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||||
model.save("./pb_simple_compile", save_format: "tf"); | |||||
} | |||||
[TestMethod] | |||||
public void SimpleModelFromSequential() | |||||
{ | |||||
Model model = keras.Sequential(new List<ILayer>() | |||||
{ | |||||
tf.keras.layers.InputLayer((28, 28, 1)), | |||||
tf.keras.layers.Flatten(), | |||||
tf.keras.layers.Dense(100, "relu"), | |||||
tf.keras.layers.Dense(10), | |||||
tf.keras.layers.Softmax() | |||||
}); | |||||
model.summary(); | |||||
model.compile(new Adam(0.001f), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | |||||
var data_loader = new MnistModelLoader(); | |||||
var num_epochs = 1; | |||||
var batch_size = 50; | |||||
var dataset = data_loader.LoadAsync(new ModelLoadSetting | |||||
{ | |||||
TrainDir = "mnist", | |||||
OneHot = false, | |||||
ValidationSize = 58000, | |||||
}).Result; | |||||
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||||
model.save("./pb_simple_sequential", save_format: "tf"); | |||||
} | |||||
[TestMethod] | |||||
public void AlexnetFromSequential() | |||||
{ | |||||
Model model = keras.Sequential(new List<ILayer>() | |||||
{ | |||||
tf.keras.layers.InputLayer((227, 227, 3)), | |||||
tf.keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"), | |||||
tf.keras.layers.BatchNormalization(), | |||||
tf.keras.layers.MaxPooling2D((3, 3), strides:(2, 2)), | |||||
tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: "relu"), | |||||
tf.keras.layers.BatchNormalization(), | |||||
tf.keras.layers.MaxPooling2D((3, 3), (2, 2)), | |||||
tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"), | |||||
tf.keras.layers.BatchNormalization(), | |||||
tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"), | |||||
tf.keras.layers.BatchNormalization(), | |||||
tf.keras.layers.Conv2D(256, (3, 3), (1, 1), "same", activation: "relu"), | |||||
tf.keras.layers.BatchNormalization(), | |||||
tf.keras.layers.MaxPooling2D((3, 3), (2, 2)), | |||||
tf.keras.layers.Flatten(), | |||||
tf.keras.layers.Dense(4096, activation: "relu"), | |||||
tf.keras.layers.Dropout(0.5f), | |||||
tf.keras.layers.Dense(4096, activation: "relu"), | |||||
tf.keras.layers.Dropout(0.5f), | |||||
tf.keras.layers.Dense(1000, activation: "linear"), | |||||
tf.keras.layers.Softmax(1) | |||||
}); | |||||
model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" }); | |||||
var num_epochs = 1; | |||||
var batch_size = 8; | |||||
var dataset = new RandomDataSet(new Shape(227, 227, 3), 16); | |||||
model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); | |||||
model.save("./alexnet_from_sequential", save_format: "tf"); | |||||
// The saved model can be test with the following python code: | |||||
#region alexnet_python_code | |||||
//import pathlib | |||||
//import tensorflow as tf | |||||
//def func(a): | |||||
// return -a | |||||
//if __name__ == '__main__': | |||||
// model = tf.keras.models.load_model("./pb_alex_sequential") | |||||
// model.summary() | |||||
// num_classes = 5 | |||||
// batch_size = 128 | |||||
// img_height = 227 | |||||
// img_width = 227 | |||||
// epochs = 100 | |||||
// dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz" | |||||
// data_dir = tf.keras.utils.get_file('flower_photos', origin = dataset_url, untar = True) | |||||
// data_dir = pathlib.Path(data_dir) | |||||
// train_ds = tf.keras.preprocessing.image_dataset_from_directory( | |||||
// data_dir, | |||||
// validation_split = 0.2, | |||||
// subset = "training", | |||||
// seed = 123, | |||||
// image_size = (img_height, img_width), | |||||
// batch_size = batch_size) | |||||
// val_ds = tf.keras.preprocessing.image_dataset_from_directory( | |||||
// data_dir, | |||||
// validation_split = 0.2, | |||||
// subset = "validation", | |||||
// seed = 123, | |||||
// image_size = (img_height, img_width), | |||||
// batch_size = batch_size) | |||||
// model.compile(optimizer = 'adam', | |||||
// loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True), | |||||
// metrics =['accuracy']) | |||||
// model.build((None, img_height, img_width, 3)) | |||||
// history = model.fit( | |||||
// train_ds, | |||||
// validation_data = val_ds, | |||||
// epochs = epochs | |||||
// ) | |||||
#endregion | |||||
} | |||||
} |