
SequentialModelSave.cs

using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using Tensorflow;
using Tensorflow.Keras;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.UnitTest.Helpers;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest.SaveModel;

[TestClass]
public class SequentialModelSave
{
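    // Builds a small MNIST classifier with the functional API (Input -> Flatten -> Dense -> Softmax),
    // trains it for one epoch, and saves it in the TensorFlow SavedModel ("tf") format.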
    [TestMethod]
    public void SimpleModelFromAutoCompile()
    {
        var inputs = tf.keras.layers.Input((28, 28, 1));
        var x = tf.keras.layers.Flatten().Apply(inputs);
        x = tf.keras.layers.Dense(100, activation: "relu").Apply(x);
        x = tf.keras.layers.Dense(units: 10).Apply(x);
        var outputs = tf.keras.layers.Softmax(axis: 1).Apply(x);
        var model = tf.keras.Model(inputs, outputs);

        model.compile(new Adam(0.001f),
            tf.keras.losses.SparseCategoricalCrossentropy(),
            new string[] { "accuracy" });
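        // ValidationSize = 58000 leaves only 2,000 of the 60,000 MNIST training samples
        // for training, which keeps this unit test fast.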
        var data_loader = new MnistModelLoader();
        var num_epochs = 1;
        var batch_size = 50;

        var dataset = data_loader.LoadAsync(new ModelLoadSetting
        {
            TrainDir = "mnist",
            OneHot = false,
            ValidationSize = 58000,
        }).Result;

        model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);

        model.save("./pb_simple_compile", save_format: "tf");
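        // Reload sketch (assumption: the installed Tensorflow.NET build exposes a Keras
        // SavedModel loader; older releases may not, in which case verify the export with
        // the Python snippet at the end of this file):
        // var loaded = tf.keras.models.load_model("./pb_simple_compile");
        // loaded.summary();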
    }
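    // Same MNIST classifier as above, but declared as a Sequential model instead of
    // through the functional API, then trained and saved as a SavedModel.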
    [TestMethod]
    public void SimpleModelFromSequential()
    {
        Model model = keras.Sequential(new List<ILayer>()
        {
            tf.keras.layers.InputLayer((28, 28, 1)),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(100, "relu"),
            tf.keras.layers.Dense(10),
            tf.keras.layers.Softmax()
        });

        model.summary();

        model.compile(new Adam(0.001f), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

        var data_loader = new MnistModelLoader();
        var num_epochs = 1;
        var batch_size = 50;

        var dataset = data_loader.LoadAsync(new ModelLoadSetting
        {
            TrainDir = "mnist",
            OneHot = false,
            ValidationSize = 58000,
        }).Result;

        model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);

        model.save("./pb_simple_sequential", save_format: "tf");
    }
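    // Builds an AlexNet-style Sequential model, fits it on a small random dataset just to
    // exercise the training path, and saves it as a SavedModel.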
    [TestMethod]
    public void AlexnetFromSequential()
    {
        Model model = keras.Sequential(new List<ILayer>()
        {
            tf.keras.layers.InputLayer((227, 227, 3)),
            tf.keras.layers.Conv2D(96, (11, 11), (4, 4), activation: "relu", padding: "valid"),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.MaxPooling2D((3, 3), strides: (2, 2)),

            tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: "relu"),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),

            tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
            tf.keras.layers.BatchNormalization(),

            tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
            tf.keras.layers.BatchNormalization(),

            tf.keras.layers.Conv2D(256, (3, 3), (1, 1), "same", activation: "relu"),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),

            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(4096, activation: "relu"),
            tf.keras.layers.Dropout(0.5f),
            tf.keras.layers.Dense(4096, activation: "relu"),
            tf.keras.layers.Dropout(0.5f),
            tf.keras.layers.Dense(1000, activation: "linear"),
            tf.keras.layers.Softmax(1)
        });

        model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

        var num_epochs = 1;
        var batch_size = 8;
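        // RandomDataSet is a test helper that generates 16 random samples of the given
        // shape, so the test does not depend on downloading a real image dataset.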
        var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);

        model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);

        model.save("./alexnet_from_sequential", save_format: "tf");
        // The saved model can be tested with the following Python code:
        #region alexnet_python_code
        //import pathlib
        //import tensorflow as tf

        //def func(a):
        //    return -a

        //if __name__ == '__main__':
        //    model = tf.keras.models.load_model("./alexnet_from_sequential")
        //    model.summary()

        //    num_classes = 5
        //    batch_size = 128
        //    img_height = 227
        //    img_width = 227
        //    epochs = 100

        //    dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
        //    data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
        //    data_dir = pathlib.Path(data_dir)

        //    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        //        data_dir,
        //        validation_split=0.2,
        //        subset="training",
        //        seed=123,
        //        image_size=(img_height, img_width),
        //        batch_size=batch_size)

        //    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        //        data_dir,
        //        validation_split=0.2,
        //        subset="validation",
        //        seed=123,
        //        image_size=(img_height, img_width),
        //        batch_size=batch_size)

        //    model.compile(optimizer='adam',
        //                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        //                  metrics=['accuracy'])

        //    model.build((None, img_height, img_width, 3))

        //    history = model.fit(
        //        train_ds,
        //        validation_data=val_ds,
        //        epochs=epochs)
        #endregion
    }
}