You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ModelLoadTest.cs 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. using Microsoft.VisualStudio.TestPlatform.Utilities;
  2. using Microsoft.VisualStudio.TestTools.UnitTesting;
  3. using Newtonsoft.Json.Linq;
  4. using System.Linq;
  5. using System.Xml.Linq;
  6. using Tensorflow.Keras.Engine;
  7. using Tensorflow.Keras.Optimizers;
  8. using Tensorflow.Keras.UnitTest.Helpers;
  9. using Tensorflow.NumPy;
  10. using static HDF.PInvoke.H5Z;
  11. using static Tensorflow.Binding;
  12. using static Tensorflow.KerasApi;
  13. namespace Tensorflow.Keras.UnitTest.Model;
  14. [TestClass]
  15. public class ModelLoadTest
  16. {
  17. [TestMethod]
  18. public void SimpleModelFromAutoCompile()
  19. {
  20. var model = tf.keras.models.load_model(@"Assets/simple_model_from_auto_compile");
  21. model.summary();
  22. model.compile(new Adam(0.0001f), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
  23. // check the weights
  24. var kernel1 = np.load(@"Assets/simple_model_from_auto_compile/kernel1.npy");
  25. var bias0 = np.load(@"Assets/simple_model_from_auto_compile/bias0.npy");
  26. Assert.IsTrue(kernel1.Zip(model.TrainableWeights[2].numpy()).All(x => x.First == x.Second));
  27. Assert.IsTrue(bias0.Zip(model.TrainableWeights[1].numpy()).All(x => x.First == x.Second));
  28. var data_loader = new MnistModelLoader();
  29. var num_epochs = 1;
  30. var batch_size = 8;
  31. var dataset = data_loader.LoadAsync(new ModelLoadSetting
  32. {
  33. TrainDir = "mnist",
  34. OneHot = false,
  35. ValidationSize = 58000,
  36. }).Result;
  37. model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
  38. }
  39. [TestMethod]
  40. public void AlexnetFromSequential()
  41. {
  42. new ModelSaveTest().AlexnetFromSequential();
  43. var model = tf.keras.models.load_model(@"./alexnet_from_sequential");
  44. model.summary();
  45. model.compile(new Adam(0.001f), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });
  46. var num_epochs = 1;
  47. var batch_size = 8;
  48. var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);
  49. model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);
  50. }
  51. [TestMethod]
  52. public void ModelWithSelfDefinedModule()
  53. {
  54. var model = tf.keras.models.load_model(@"Assets/python_func_model");
  55. model.summary();
  56. model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
  57. var data_loader = new MnistModelLoader();
  58. var num_epochs = 1;
  59. var batch_size = 8;
  60. var dataset = data_loader.LoadAsync(new ModelLoadSetting
  61. {
  62. TrainDir = "mnist",
  63. OneHot = false,
  64. ValidationSize = 55000,
  65. }).Result;
  66. model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
  67. }
  68. [Ignore]
  69. [TestMethod]
  70. public void LSTMLoad()
  71. {
  72. var model = tf.keras.models.load_model(@"Assets/lstm_from_sequential");
  73. model.summary();
  74. model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.MeanSquaredError(), new string[] { "accuracy" });
  75. var inputs = tf.random.normal(shape: (10, 5, 3));
  76. var outputs = tf.random.normal(shape: (10, 1));
  77. model.fit(inputs.numpy(), outputs.numpy(), batch_size: 10, epochs: 5, workers: 16, use_multiprocessing: true);
  78. }
  79. [Ignore]
  80. [TestMethod]
  81. public void VGG19()
  82. {
  83. var model = tf.keras.models.load_model(@"D:\development\tf.net\models\VGG19");
  84. model.summary();
  85. var classify_model = keras.Sequential(new System.Collections.Generic.List<ILayer>()
  86. {
  87. model,
  88. keras.layers.Flatten(),
  89. keras.layers.Dense(10),
  90. });
  91. classify_model.summary();
  92. classify_model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
  93. var x = np.random.uniform(0, 1, (8, 512, 512, 3));
  94. var y = np.ones(8);
  95. classify_model.fit(x, y, batch_size: 4);
  96. }
  97. [Ignore]
  98. [TestMethod]
  99. public void TestModelBeforeTF2_5()
  100. {
  101. var a = keras.layers;
  102. var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model;
  103. model.summary();
  104. }
  105. [TestMethod]
  106. public void CreateConcatenateModelSaveAndLoad()
  107. {
  108. // a small demo model that is just here to see if the axis value for the concatenate method is saved and loaded.
  109. var input_layer = tf.keras.layers.Input((8, 8, 5));
  110. var conv1 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_1"*/).Apply(input_layer);
  111. conv1.Name = "conv1";
  112. var conv2 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_2"*/).Apply(input_layer);
  113. conv2.Name = "conv2";
  114. var concat1 = tf.keras.layers.Concatenate(axis: 3).Apply((conv1, conv2));
  115. concat1.Name = "concat1";
  116. var model = tf.keras.Model(input_layer, concat1);
  117. model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy());
  118. model.save(@"Assets/concat_axis3_model");
  119. var tensorInput = np.arange(320).reshape((1, 8, 8, 5)).astype(TF_DataType.TF_FLOAT);
  120. var tensors1 = model.predict(tensorInput);
  121. Assert.AreEqual((1, 8, 8, 4), tensors1.shape);
  122. model = null;
  123. keras.backend.clear_session();
  124. var model2 = tf.keras.models.load_model(@"Assets/concat_axis3_model");
  125. var tensors2 = model2.predict(tensorInput);
  126. Assert.AreEqual(tensors1.shape, tensors2.shape);
  127. }
  128. }