SequentialModelSave.cs

using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Tensorflow;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Operations;
using System.Diagnostics;

namespace TensorFlowNET.Keras.UnitTest.SaveModel;

[TestClass]
public class SequentialModelTest
{
    [TestMethod]
    public void SimpleModelFromAutoCompile()
    {
        // Build a small functional model: Flatten -> Dense(100, relu) -> Dense(10) -> Softmax.
        var inputs = new KerasInterface().Input((28, 28, 1));
        var x = new Flatten(new FlattenArgs()).Apply(inputs);
        x = new Dense(new DenseArgs() { Units = 100, Activation = tf.nn.relu }).Apply(x);
        x = new LayersApi().Dense(units: 10).Apply(x);
        var outputs = new LayersApi().Softmax(axis: 1).Apply(x);
        var model = new KerasInterface().Model(inputs, outputs);

        model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });

        // Train for one epoch on MNIST, then save in the TensorFlow SavedModel format.
        var data_loader = new MnistModelLoader();
        var num_epochs = 1;
        var batch_size = 50;

        var dataset = data_loader.LoadAsync(new ModelLoadSetting
        {
            TrainDir = "mnist",
            OneHot = false,
            ValidationSize = 10000,
        }).Result;

        model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);

        model.save("./pb_simple_compile", save_format: "tf");
    }
    [TestMethod]
    public void SimpleModelFromSequential()
    {
        // Same architecture as above, but built with the Sequential API.
        Model model = KerasApi.keras.Sequential(new List<ILayer>()
        {
            keras.layers.InputLayer((28, 28, 1)),
            keras.layers.Flatten(),
            keras.layers.Dense(100, "relu"),
            keras.layers.Dense(10),
            keras.layers.Softmax(1)
        });

        model.summary();

        model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });

        var data_loader = new MnistModelLoader();
        var num_epochs = 1;
        var batch_size = 50;

        var dataset = data_loader.LoadAsync(new ModelLoadSetting
        {
            TrainDir = "mnist",
            OneHot = false,
            ValidationSize = 50000,
        }).Result;

        model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);

        model.save("./pb_simple_sequential", save_format: "tf");
    }
    [TestMethod]
    public void AlexModelFromSequential()
    {
        // AlexNet-style architecture built with the Sequential API.
        Model model = KerasApi.keras.Sequential(new List<ILayer>()
        {
            keras.layers.InputLayer((227, 227, 3)),
            keras.layers.Conv2D(96, (11, 11), (4, 4), activation: "relu", padding: "valid"),
            keras.layers.BatchNormalization(),
            keras.layers.MaxPooling2D((3, 3), strides: (2, 2)),

            keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: "relu"),
            keras.layers.BatchNormalization(),
            keras.layers.MaxPooling2D((3, 3), (2, 2)),

            keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
            keras.layers.BatchNormalization(),

            keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
            keras.layers.BatchNormalization(),

            keras.layers.Conv2D(256, (3, 3), (1, 1), "same", activation: "relu"),
            keras.layers.BatchNormalization(),
            keras.layers.MaxPooling2D((3, 3), (2, 2)),

            keras.layers.Flatten(),
            keras.layers.Dense(4096, activation: "relu"),
            keras.layers.Dropout(0.5f),
            keras.layers.Dense(4096, activation: "relu"),
            keras.layers.Dropout(0.5f),
            keras.layers.Dense(1000, activation: "linear"),
            keras.layers.Softmax(1)
        });

        model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

        // Train on a small random dataset just to exercise the graph, then save.
        var num_epochs = 1;
        var batch_size = 8;

        var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);

        model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);

        model.save("./pb_alex_sequential", save_format: "tf");
        // The saved model can be tested with the following Python code:
        #region alexnet_python_code
        //import pathlib
        //import tensorflow as tf
        //def func(a):
        //    return -a
        //if __name__ == '__main__':
        //    model = tf.keras.models.load_model("./pb_alex_sequential")
        //    model.summary()
        //    num_classes = 5
        //    batch_size = 128
        //    img_height = 227
        //    img_width = 227
        //    epochs = 100
        //    dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
        //    data_dir = tf.keras.utils.get_file('flower_photos', origin = dataset_url, untar = True)
        //    data_dir = pathlib.Path(data_dir)
        //    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        //        data_dir,
        //        validation_split = 0.2,
        //        subset = "training",
        //        seed = 123,
        //        image_size = (img_height, img_width),
        //        batch_size = batch_size)
        //    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        //        data_dir,
        //        validation_split = 0.2,
        //        subset = "validation",
        //        seed = 123,
        //        image_size = (img_height, img_width),
        //        batch_size = batch_size)
        //    model.compile(optimizer = 'adam',
        //                  loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True),
        //                  metrics = ['accuracy'])
        //    model.build((None, img_height, img_width, 3))
        //    history = model.fit(
        //        train_ds,
        //        validation_data = val_ds,
        //        epochs = epochs
        //    )
        #endregion
    }
    public class RandomDataSet : DataSetBase
    {
        private Shape _shape;

        public RandomDataSet(Shape shape, int count)
        {
            _shape = shape;
            Debug.Assert(_shape.ndim == 3);

            // Prepend the sample count to the per-sample shape, e.g. (227, 227, 3) -> (count, 227, 227, 3).
            long[] dims = new long[4];
            dims[0] = count;
            for (int i = 1; i < 4; i++)
            {
                dims[i] = _shape[i - 1];
            }
            Shape s = new Shape(dims);

            // Random inputs and labels; the values are meaningless, the test only checks that fit/save run.
            Data = np.random.normal(0, 2, s);
            Labels = np.random.uniform(0, 1, (count, 1));
        }
    }
}
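
For reference, the SavedModel directories written by these tests can also be reloaded from C# rather than Python. The sketch below is not part of the test file; it assumes the installed TensorFlow.NET version exposes keras.models.load_model and Model.predict (present in recent releases), and the ReloadSimpleSequentialModel name is made up for illustration.

// Hypothetical follow-up test: reload the SavedModel produced by SimpleModelFromSequential
// and run a single prediction. Assumes keras.models.load_model and model.predict are
// available in the TensorFlow.NET version in use.
[TestMethod]
public void ReloadSimpleSequentialModel()
{
    var model = keras.models.load_model("./pb_simple_sequential");
    model.summary();

    // One fake 28x28x1 sample; a real test would feed MNIST images instead.
    var sample = np.ones((1, 28, 28, 1), np.float32);
    var output = model.predict(sample);

    // The final Softmax layer should yield a (1, 10) tensor of class probabilities.
    Assert.AreEqual(10, (int)output[0].shape[1]);
}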