You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

SequentialModelSave.cs 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using Tensorflow;
  5. using Tensorflow.Keras;
  6. using Tensorflow.Keras.ArgsDefinition;
  7. using Tensorflow.Keras.Engine;
  8. using Tensorflow.Keras.Layers;
  9. using Tensorflow.Keras.Losses;
  10. using Tensorflow.Keras.Optimizers;
  11. using Tensorflow.NumPy;
  12. using static Tensorflow.Binding;
  13. using static Tensorflow.KerasApi;
  14. namespace TensorFlowNET.Keras.UnitTest.SaveModel;
  15. [TestClass]
  16. public class SequentialModelSave
  17. {
  18. [TestMethod]
  19. public void SimpleModelFromAutoCompile()
  20. {
  21. var inputs = new KerasInterface().Input((28, 28, 1));
  22. var x = new Flatten(new FlattenArgs()).Apply(inputs);
  23. x = new Dense(new DenseArgs() { Units = 100, Activation = tf.nn.relu }).Apply(x);
  24. x = new LayersApi().Dense(units: 10).Apply(x);
  25. var outputs = new LayersApi().Softmax(axis: 1).Apply(x);
  26. var model = new KerasInterface().Model(inputs, outputs);
  27. model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });
  28. var data_loader = new MnistModelLoader();
  29. var num_epochs = 1;
  30. var batch_size = 50;
  31. var dataset = data_loader.LoadAsync(new ModelLoadSetting
  32. {
  33. TrainDir = "mnist",
  34. OneHot = false,
  35. ValidationSize = 10000,
  36. }).Result;
  37. model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
  38. model.save("./pb_simple_compile", save_format: "tf");
  39. }
  40. [TestMethod]
  41. public void SimpleModelFromSequential()
  42. {
  43. Model model = KerasApi.keras.Sequential(new List<ILayer>()
  44. {
  45. keras.layers.InputLayer((28, 28, 1)),
  46. keras.layers.Flatten(),
  47. keras.layers.Dense(100, "relu"),
  48. keras.layers.Dense(10),
  49. keras.layers.Softmax(1)
  50. });
  51. model.summary();
  52. model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });
  53. var data_loader = new MnistModelLoader();
  54. var num_epochs = 1;
  55. var batch_size = 50;
  56. var dataset = data_loader.LoadAsync(new ModelLoadSetting
  57. {
  58. TrainDir = "mnist",
  59. OneHot = false,
  60. ValidationSize = 50000,
  61. }).Result;
  62. model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
  63. model.save("./pb_simple_sequential", save_format: "tf");
  64. }
  65. [TestMethod]
  66. public void AlexnetFromSequential()
  67. {
  68. Model model = KerasApi.keras.Sequential(new List<ILayer>()
  69. {
  70. keras.layers.InputLayer((227, 227, 3)),
  71. keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"),
  72. keras.layers.BatchNormalization(),
  73. keras.layers.MaxPooling2D((3, 3), strides:(2, 2)),
  74. keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: "relu"),
  75. keras.layers.BatchNormalization(),
  76. keras.layers.MaxPooling2D((3, 3), (2, 2)),
  77. keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
  78. keras.layers.BatchNormalization(),
  79. keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
  80. keras.layers.BatchNormalization(),
  81. keras.layers.Conv2D(256, (3, 3), (1, 1), "same", activation: "relu"),
  82. keras.layers.BatchNormalization(),
  83. keras.layers.MaxPooling2D((3, 3), (2, 2)),
  84. keras.layers.Flatten(),
  85. keras.layers.Dense(4096, activation: "relu"),
  86. keras.layers.Dropout(0.5f),
  87. keras.layers.Dense(4096, activation: "relu"),
  88. keras.layers.Dropout(0.5f),
  89. keras.layers.Dense(1000, activation: "linear"),
  90. keras.layers.Softmax(1)
  91. });
  92. model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });
  93. var num_epochs = 1;
  94. var batch_size = 8;
  95. var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);
  96. model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);
  97. model.save("./alexnet_from_sequential", save_format: "tf");
  98. // The saved model can be test with the following python code:
  99. #region alexnet_python_code
  100. //import pathlib
  101. //import tensorflow as tf
  102. //def func(a):
  103. // return -a
  104. //if __name__ == '__main__':
  105. // model = tf.keras.models.load_model("./pb_alex_sequential")
  106. // model.summary()
  107. // num_classes = 5
  108. // batch_size = 128
  109. // img_height = 227
  110. // img_width = 227
  111. // epochs = 100
  112. // dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
  113. // data_dir = tf.keras.utils.get_file('flower_photos', origin = dataset_url, untar = True)
  114. // data_dir = pathlib.Path(data_dir)
  115. // train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  116. // data_dir,
  117. // validation_split = 0.2,
  118. // subset = "training",
  119. // seed = 123,
  120. // image_size = (img_height, img_width),
  121. // batch_size = batch_size)
  122. // val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  123. // data_dir,
  124. // validation_split = 0.2,
  125. // subset = "validation",
  126. // seed = 123,
  127. // image_size = (img_height, img_width),
  128. // batch_size = batch_size)
  129. // model.compile(optimizer = 'adam',
  130. // loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True),
  131. // metrics =['accuracy'])
  132. // model.build((None, img_height, img_width, 3))
  133. // history = model.fit(
  134. // train_ds,
  135. // validation_data = val_ds,
  136. // epochs = epochs
  137. // )
  138. #endregion
  139. }
  140. public class RandomDataSet : DataSetBase
  141. {
  142. private Shape _shape;
  143. public RandomDataSet(Shape shape, int count)
  144. {
  145. _shape = shape;
  146. Debug.Assert(_shape.ndim == 3);
  147. long[] dims = new long[4];
  148. dims[0] = count;
  149. for (int i = 1; i < 4; i++)
  150. {
  151. dims[i] = _shape[i - 1];
  152. }
  153. Shape s = new Shape(dims);
  154. Data = np.random.normal(0, 2, s);
  155. Labels = np.random.uniform(0, 1, (count, 1));
  156. }
  157. }
  158. }