You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ModelSaveTest.cs 7.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using Tensorflow.Keras.Engine;
  5. using Tensorflow.Keras.Models;
  6. using Tensorflow.Keras.Optimizers;
  7. using Tensorflow.Keras.Saving;
  8. using Tensorflow.Keras.UnitTest.Helpers;
  9. using static Tensorflow.Binding;
  10. using static Tensorflow.KerasApi;
  namespace Tensorflow.Keras.UnitTest.Model
  {
      /// <summary>
      /// Model save / serialization tests, modelled on
      /// https://www.tensorflow.org/guide/keras/save_and_serialize
      /// </summary>
      [TestClass]
      public class ModelSaveTest : EagerModeTestBase
      {
          /// <summary>
          /// Round-trips a functional model through <c>get_config</c> / <c>from_config</c>
          /// and checks that the rebuilt model has the same number of layers.
          /// </summary>
          [TestMethod]
          public void GetAndFromConfig()
          {
              var model = GetFunctionalModel();

              var config = model.get_config();
              // Use the test framework's assertion rather than Debug.Assert: Debug.Assert is
              // [Conditional("DEBUG")] and is compiled out of Release builds, so the check
              // would silently disappear there instead of failing the test.
              Assert.IsInstanceOfType(config, typeof(FunctionalConfig));
              var new_model = new ModelsApi().from_config((FunctionalConfig)config);

              Assert.AreEqual(model.Layers.Count, new_model.Layers.Count);
          }

          /// <summary>
          /// Builds a minimal functional model: <c>keras.Input(shape: 32)</c> feeding a single
          /// <c>Dense(1)</c> layer.
          /// </summary>
          /// <returns>The functional model wiring the input to the dense layer's output.</returns>
          private IModel GetFunctionalModel()
          {
              var inputs = keras.Input(shape: 32);
              var dense_layer = keras.layers.Dense(1);
              var outputs = dense_layer.Apply(inputs);
              return keras.Model(inputs, outputs);
          }

          /// <summary>
          /// Builds a functional MNIST classifier (Flatten -> Dense(100, relu) -> Dense(10) -> Softmax),
          /// trains it for one epoch and saves it in "tf" (SavedModel) format to ./pb_simple_compile.
          /// </summary>
          [TestMethod]
          public void SimpleModelFromAutoCompile()
          {
              var inputs = tf.keras.layers.Input((28, 28, 1));
              var x = tf.keras.layers.Flatten().Apply(inputs);
              x = tf.keras.layers.Dense(100, activation: "relu").Apply(x);
              x = tf.keras.layers.Dense(units: 10).Apply(x);
              // The model ends in Softmax (probabilities), so the loss is built without from_logits.
              var outputs = tf.keras.layers.Softmax(axis: 1).Apply(x);
              var model = tf.keras.Model(inputs, outputs);
              model.compile(new Adam(0.001f),
                  tf.keras.losses.SparseCategoricalCrossentropy(),
                  new string[] { "accuracy" });

              var data_loader = new MnistModelLoader();
              var num_epochs = 1;
              var batch_size = 50;
              // NOTE(review): ValidationSize = 58000 presumably moves most samples out of the
              // training split to keep this test fast — confirm against MnistModelLoader.
              // GetAwaiter().GetResult() (unlike .Result) rethrows the original exception
              // rather than wrapping it in an AggregateException.
              var dataset = data_loader.LoadAsync(new ModelLoadSetting
              {
                  TrainDir = "mnist",
                  OneHot = false,
                  ValidationSize = 58000,
              }).GetAwaiter().GetResult();
              model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);

              model.save("./pb_simple_compile", save_format: "tf");
          }

          /// <summary>
          /// Same MNIST classifier as <see cref="SimpleModelFromAutoCompile"/>, built with
          /// <c>keras.Sequential</c>; trains for one epoch and saves to ./pb_simple_sequential.
          /// </summary>
          [TestMethod]
          public void SimpleModelFromSequential()
          {
              var model = keras.Sequential(new List<ILayer>()
              {
                  tf.keras.layers.InputLayer((28, 28, 1)),
                  tf.keras.layers.Flatten(),
                  tf.keras.layers.Dense(100, "relu"),
                  tf.keras.layers.Dense(10),
                  tf.keras.layers.Softmax()
              });
              model.summary();
              // The model ends in Softmax (probabilities), so the loss is built without from_logits.
              model.compile(new Adam(0.001f), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

              var data_loader = new MnistModelLoader();
              var num_epochs = 1;
              var batch_size = 50;
              // GetAwaiter().GetResult() (unlike .Result) rethrows the original exception
              // rather than wrapping it in an AggregateException.
              var dataset = data_loader.LoadAsync(new ModelLoadSetting
              {
                  TrainDir = "mnist",
                  OneHot = false,
                  ValidationSize = 58000,
              }).GetAwaiter().GetResult();
              model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);

              model.save("./pb_simple_sequential", save_format: "tf");
          }

          /// <summary>
          /// Builds an AlexNet-style Sequential model on 227x227x3 inputs, trains it for one epoch
          /// on a small random dataset and saves it in "tf" format to ./alexnet_from_sequential.
          /// </summary>
          [TestMethod]
          public void AlexnetFromSequential()
          {
              var model = keras.Sequential(new List<ILayer>()
              {
                  tf.keras.layers.InputLayer((227, 227, 3)),
                  tf.keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"),
                  tf.keras.layers.BatchNormalization(),
                  tf.keras.layers.MaxPooling2D((3, 3), strides:(2, 2)),
                  tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: "relu"),
                  tf.keras.layers.BatchNormalization(),
                  tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),
                  tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
                  tf.keras.layers.BatchNormalization(),
                  tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
                  tf.keras.layers.BatchNormalization(),
                  tf.keras.layers.Conv2D(256, (3, 3), (1, 1), "same", activation: "relu"),
                  tf.keras.layers.BatchNormalization(),
                  tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),
                  tf.keras.layers.Flatten(),
                  tf.keras.layers.Dense(4096, activation: "relu"),
                  tf.keras.layers.Dropout(0.5f),
                  tf.keras.layers.Dense(4096, activation: "relu"),
                  tf.keras.layers.Dropout(0.5f),
                  tf.keras.layers.Dense(1000, activation: "linear"),
                  tf.keras.layers.Softmax(1)
              });
              // The last layer is Softmax, so the model outputs probabilities, not logits.
              // from_logits must therefore be false (matching the MNIST tests above);
              // from_logits: true would apply softmax a second time inside the loss.
              model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: false), new string[] { "accuracy" });

              var num_epochs = 1;
              var batch_size = 8;
              var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);
              model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);

              model.save("./alexnet_from_sequential", save_format: "tf");

              // The saved model can be tested with the following Python code:
              #region alexnet_python_code
              //import pathlib
              //import tensorflow as tf
              //
              //if __name__ == '__main__':
              //    model = tf.keras.models.load_model("./alexnet_from_sequential")
              //    model.summary()
              //    num_classes = 5
              //    batch_size = 128
              //    img_height = 227
              //    img_width = 227
              //    epochs = 100
              //    dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
              //    data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
              //    data_dir = pathlib.Path(data_dir)
              //    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
              //        data_dir,
              //        validation_split=0.2,
              //        subset="training",
              //        seed=123,
              //        image_size=(img_height, img_width),
              //        batch_size=batch_size)
              //    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
              //        data_dir,
              //        validation_split=0.2,
              //        subset="validation",
              //        seed=123,
              //        image_size=(img_height, img_width),
              //        batch_size=batch_size)
              //    # The saved model ends in Softmax, so its outputs are probabilities.
              //    model.compile(optimizer='adam',
              //                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              //                  metrics=['accuracy'])
              //    model.build((None, img_height, img_width, 3))
              //    history = model.fit(
              //        train_ds,
              //        validation_data=val_ds,
              //        epochs=epochs
              //    )
              #endregion
          }

          /// <summary>
          /// Loads the model stored at Assets/simple_model_from_auto_compile, prints its summary
          /// and saves it again to Assets/saved_auto_compile_after_loading.
          /// </summary>
          [TestMethod]
          public void SaveAfterLoad()
          {
              var model = tf.keras.models.load_model(@"Assets/simple_model_from_auto_compile");
              model.summary();
              model.save("Assets/saved_auto_compile_after_loading");

              // NOTE(review): reloading the re-saved model is currently disabled; re-enable these
              // lines to also verify that the saved output can be loaded back.
              //model = tf.keras.models.load_model(@"Assets/saved_auto_compile_after_loading");
              //model.summary();
          }
      }
  }