You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ModelLoadTest.cs 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. using Microsoft.VisualStudio.TestPlatform.Utilities;
  2. using Microsoft.VisualStudio.TestTools.UnitTesting;
  3. using Newtonsoft.Json.Linq;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Xml.Linq;
  7. using Tensorflow.Keras.Engine;
  8. using Tensorflow.Keras.Optimizers;
  9. using Tensorflow.Keras.UnitTest.Helpers;
  10. using Tensorflow.NumPy;
  11. using static HDF.PInvoke.H5Z;
  12. using static Tensorflow.Binding;
  13. using static Tensorflow.KerasApi;
  14. namespace Tensorflow.Keras.UnitTest.Model;
  /// <summary>
  /// Tests that load Keras models from disk, then check that the loaded models can be
  /// summarized, compiled and trained (and, where reference data exists, that their
  /// weights or outputs match what was saved).
  /// </summary>
  [TestClass]
  public class ModelLoadTest
  {
      /// <summary>
      /// Loads the model stored in Assets/simple_model_from_auto_compile, compares two of its
      /// trainable weights with the .npy reference files stored next to it, and then trains it
      /// for one epoch on MNIST.
      /// </summary>
      [TestMethod]
      public void SimpleModelFromAutoCompile()
      {
          var model = tf.keras.models.load_model(@"Assets/simple_model_from_auto_compile");
          model.summary();
          model.compile(new Adam(0.0001f), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

          // Check the loaded weights against the reference values exported with the model.
          // TrainableWeights[2] is compared with kernel1.npy and TrainableWeights[1] with bias0.npy.
          var kernel1 = np.load(@"Assets/simple_model_from_auto_compile/kernel1.npy");
          var bias0 = np.load(@"Assets/simple_model_from_auto_compile/bias0.npy");
          var loadedKernel1 = model.TrainableWeights[2].numpy();
          var loadedBias0 = model.TrainableWeights[1].numpy();

          // Enumerable.Zip stops at the end of the shorter sequence, so without a shape check a
          // size mismatch (or an empty array) could make the element-wise check pass vacuously.
          Assert.AreEqual(kernel1.shape, loadedKernel1.shape);
          Assert.AreEqual(bias0.shape, loadedBias0.shape);
          Assert.IsTrue(kernel1.Zip(loadedKernel1).All(x => x.First == x.Second));
          Assert.IsTrue(bias0.Zip(loadedBias0).All(x => x.First == x.Second));

          var data_loader = new MnistModelLoader();
          var num_epochs = 1;
          var batch_size = 8;
          // NOTE(review): the large ValidationSize presumably keeps the training subset small so
          // the test runs quickly — confirm against MnistModelLoader.
          var dataset = data_loader.LoadAsync(new ModelLoadSetting
          {
              TrainDir = "mnist",
              OneHot = false,
              ValidationSize = 58000,
          }).Result;
          model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
      }

      /// <summary>
      /// Runs <c>ModelSaveTest.AlexnetFromSequential</c>, loads the model from
      /// ./alexnet_from_sequential, and trains it for one epoch on 16 random 227x227x3 samples.
      /// </summary>
      [TestMethod]
      public void AlexnetFromSequential()
      {
          // Presumably produces ./alexnet_from_sequential, which is loaded below — TODO confirm.
          new ModelSaveTest().AlexnetFromSequential();
          var model = tf.keras.models.load_model(@"./alexnet_from_sequential");
          model.summary();
          model.compile(new Adam(0.001f), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });
          var num_epochs = 1;
          var batch_size = 8;
          var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);
          model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);
      }

      /// <summary>
      /// Loads the model stored in Assets/python_func_model and trains it for one epoch on MNIST.
      /// </summary>
      [TestMethod]
      public void ModelWithSelfDefinedModule()
      {
          var model = tf.keras.models.load_model(@"Assets/python_func_model");
          model.summary();
          model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
          var data_loader = new MnistModelLoader();
          var num_epochs = 1;
          var batch_size = 8;
          var dataset = data_loader.LoadAsync(new ModelLoadSetting
          {
              TrainDir = "mnist",
              OneHot = false,
              ValidationSize = 55000,
          }).Result;
          model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
      }

      /// <summary>
      /// Loads the model stored in Assets/lstm_from_sequential and fits it on random
      /// (10, 5, 3) inputs / (10, 1) targets with MSE loss, 16 workers and multiprocessing.
      /// </summary>
      [Ignore]
      [TestMethod]
      public void LSTMLoad()
      {
          var model = tf.keras.models.load_model(@"Assets/lstm_from_sequential");
          model.summary();
          model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.MeanSquaredError(), new string[] { "accuracy" });
          var inputs = tf.random.normal(shape: (10, 5, 3));
          var outputs = tf.random.normal(shape: (10, 1));
          model.fit(inputs.numpy(), outputs.numpy(), batch_size: 10, epochs: 5, workers: 16, use_multiprocessing: true);
      }

      /// <summary>
      /// Loads a VGG19 model, appends Flatten + Dense(10) in a Sequential wrapper, and fits it
      /// on 8 random 512x512x3 inputs whose labels are all ones.
      /// </summary>
      /// <remarks>NOTE(review): uses a machine-local absolute path, so it only runs on the author's machine.</remarks>
      [Ignore]
      [TestMethod]
      public void VGG19()
      {
          var model = tf.keras.models.load_model(@"D:\development\tf.net\models\VGG19");
          model.summary();
          var classify_model = keras.Sequential(new List<ILayer>()
          {
              model,
              keras.layers.Flatten(),
              keras.layers.Dense(10),
          });
          classify_model.summary();
          classify_model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
          var x = np.random.uniform(0, 1, (8, 512, 512, 3));
          var y = np.ones(8);
          classify_model.fit(x, y, batch_size: 4);
      }

      /// <summary>
      /// Loads a SavedModel through <c>tf.saved_model.load</c> (rather than the Keras loader)
      /// and prints its summary as a Keras <see cref="Tensorflow.Keras.Engine.Model"/>.
      /// </summary>
      /// <remarks>NOTE(review): uses a machine-local absolute path, so it only runs on the author's machine.</remarks>
      [Ignore]
      [TestMethod]
      public void TestModelBeforeTF2_5()
      {
          // NOTE(review): `a` is never read. It may exist only to touch `keras` (running its
          // initialization) before tf.saved_model.load — confirm before removing.
          var a = keras.layers;
          var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model;
          // `as` yields null when the loaded object is not a Keras Model; fail with a clear
          // message instead of a NullReferenceException on the next line.
          Assert.IsNotNull(model, "tf.saved_model.load did not return a Tensorflow.Keras.Engine.Model.");
          model.summary();
      }

      /// <summary>
      /// Builds an AlexNet-style Sequential model whose Conv2D layers use L1L2, L2 and L1 bias
      /// regularizers. It trains the model for one epoch, saves it in "tf" format to
      /// ./bias_regularizer_save_and_load, reloads it, and trains the reloaded model for one epoch.
      /// </summary>
      [TestMethod]
      public void BiasRegularizerSaveAndLoad()
      {
          var savemodel = keras.Sequential(new List<ILayer>()
          {
              tf.keras.layers.InputLayer((227, 227, 3)),
              tf.keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"),
              tf.keras.layers.BatchNormalization(),
              tf.keras.layers.MaxPooling2D((3, 3), strides:(2, 2)),

              tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L1L2),
              tf.keras.layers.BatchNormalization(),

              tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L2),
              tf.keras.layers.BatchNormalization(),

              tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L1),
              tf.keras.layers.BatchNormalization(),
              tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),

              tf.keras.layers.Flatten(),

              tf.keras.layers.Dense(1000, activation: "linear"),
              tf.keras.layers.Softmax(1)
          });

          // NOTE(review): the model ends in Softmax, yet the loss is told from_logits: true.
          // This is harmless for a save/load smoke test but not a sound training setup.
          savemodel.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

          var num_epochs = 1;
          var batch_size = 8;

          var trainDataset = new RandomDataSet(new Shape(227, 227, 3), 16);

          savemodel.fit(trainDataset.Data, trainDataset.Labels, batch_size, num_epochs);

          savemodel.save(@"./bias_regularizer_save_and_load", save_format: "tf");

          var loadModel = tf.keras.models.load_model(@"./bias_regularizer_save_and_load");
          loadModel.summary();

          loadModel.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

          var fitDataset = new RandomDataSet(new Shape(227, 227, 3), 16);

          loadModel.fit(fitDataset.Data, fitDataset.Labels, batch_size, num_epochs);
      }

      /// <summary>
      /// Builds a small functional model that concatenates two 2-filter Conv2D branches on
      /// axis 3, saves it to Assets/concat_axis3_model, and reloads it after clearing the session.
      /// The model's 4-channel output shape (1, 8, 8, 4) and the reloaded model's matching
      /// shape show that the concatenate axis survives the save/load round trip.
      /// </summary>
      [TestMethod]
      public void CreateConcatenateModelSaveAndLoad()
      {
          // A small demo model used only to check that the Concatenate axis value is saved and loaded.
          var input_layer = tf.keras.layers.Input((8, 8, 5));
          var conv1 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same").Apply(input_layer);
          conv1.Name = "conv1";

          var conv2 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same").Apply(input_layer);
          conv2.Name = "conv2";

          // Concatenating two 2-filter branches on axis 3 yields 4 channels.
          var concat1 = tf.keras.layers.Concatenate(axis: 3).Apply((conv1, conv2));
          concat1.Name = "concat1";

          var model = tf.keras.Model(input_layer, concat1);
          model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy());

          model.save(@"Assets/concat_axis3_model");

          var tensorInput = np.arange(320).reshape((1, 8, 8, 5)).astype(TF_DataType.TF_FLOAT);

          var tensors1 = model.predict(tensorInput);
          Assert.AreEqual((1, 8, 8, 4), tensors1.shape);

          // Presumably drops the original model and session state so the reloaded model
          // cannot reuse it — TODO confirm this is what clear_session guarantees here.
          model = null;
          keras.backend.clear_session();

          var model2 = tf.keras.models.load_model(@"Assets/concat_axis3_model");

          var tensors2 = model2.predict(tensorInput);

          Assert.AreEqual(tensors1.shape, tensors2.shape);
      }
  }