|
|
@@ -1,10 +1,13 @@ |
|
|
|
using Microsoft.VisualStudio.TestPlatform.Utilities; |
|
|
|
using Microsoft.VisualStudio.TestTools.UnitTesting; |
|
|
|
using Newtonsoft.Json.Linq; |
|
|
|
using System.Linq; |
|
|
|
using System.Xml.Linq; |
|
|
|
using Tensorflow.Keras.Engine; |
|
|
|
using Tensorflow.Keras.Optimizers; |
|
|
|
using Tensorflow.Keras.UnitTest.Helpers; |
|
|
|
using Tensorflow.NumPy; |
|
|
|
using static HDF.PInvoke.H5Z; |
|
|
|
using static Tensorflow.Binding; |
|
|
|
using static Tensorflow.KerasApi; |
|
|
|
|
|
|
@@ -124,4 +127,44 @@ public class ModelLoadTest |
|
|
|
var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model; |
|
|
|
model.summary(); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
[TestMethod] |
|
|
|
public void CreateConcatenateModelSaveAndLoad() |
|
|
|
{ |
|
|
|
// a small demo model that is just here to see if the axis value for the concatenate method is saved and loaded. |
|
|
|
var input_layer = tf.keras.layers.Input((8, 8, 5)); |
|
|
|
|
|
|
|
var conv1 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_1"*/).Apply(input_layer); |
|
|
|
conv1.Name = "conv1"; |
|
|
|
|
|
|
|
var conv2 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_2"*/).Apply(input_layer); |
|
|
|
conv2.Name = "conv2"; |
|
|
|
|
|
|
|
var concat1 = tf.keras.layers.Concatenate(axis: 3).Apply((conv1, conv2)); |
|
|
|
concat1.Name = "concat1"; |
|
|
|
|
|
|
|
var model = tf.keras.Model(input_layer, concat1); |
|
|
|
model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy()); |
|
|
|
|
|
|
|
model.save(@"Assets/concat_axis3_model"); |
|
|
|
|
|
|
|
|
|
|
|
var tensorInput = np.arange(320).reshape((1, 8, 8, 5)).astype(TF_DataType.TF_FLOAT); |
|
|
|
|
|
|
|
var tensors1 = model.predict(tensorInput); |
|
|
|
|
|
|
|
Assert.AreEqual((1, 8, 8, 4), tensors1.shape); |
|
|
|
|
|
|
|
model = null; |
|
|
|
keras.backend.clear_session(); |
|
|
|
|
|
|
|
var model2 = tf.keras.models.load_model(@"Assets/concat_axis3_model"); |
|
|
|
|
|
|
|
var tensors2 = model2.predict(tensorInput); |
|
|
|
|
|
|
|
Assert.AreEqual(tensors1.shape, tensors2.shape); |
|
|
|
} |
|
|
|
|
|
|
|
} |