You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ModelLoadTest.cs 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. using Microsoft.VisualStudio.TestPlatform.Utilities;
  2. using Microsoft.VisualStudio.TestTools.UnitTesting;
  3. using System.Linq;
  4. using Tensorflow.Keras.Engine;
  5. using Tensorflow.Keras.Optimizers;
  6. using Tensorflow.Keras.UnitTest.Helpers;
  7. using Tensorflow.NumPy;
  8. using static Tensorflow.Binding;
  9. using static Tensorflow.KerasApi;
  10. namespace Tensorflow.Keras.UnitTest.Model;
  11. [TestClass]
  12. public class ModelLoadTest
  13. {
  14. [TestMethod]
  15. public void SimpleModelFromAutoCompile()
  16. {
  17. var model = tf.keras.models.load_model(@"Assets/simple_model_from_auto_compile");
  18. model.summary();
  19. model.compile(new Adam(0.0001f), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
  20. // check the weights
  21. var kernel1 = np.load(@"Assets/simple_model_from_auto_compile/kernel1.npy");
  22. var bias0 = np.load(@"Assets/simple_model_from_auto_compile/bias0.npy");
  23. Assert.IsTrue(kernel1.Zip(model.TrainableWeights[2].numpy()).All(x => x.First == x.Second));
  24. Assert.IsTrue(bias0.Zip(model.TrainableWeights[1].numpy()).All(x => x.First == x.Second));
  25. var data_loader = new MnistModelLoader();
  26. var num_epochs = 1;
  27. var batch_size = 8;
  28. var dataset = data_loader.LoadAsync(new ModelLoadSetting
  29. {
  30. TrainDir = "mnist",
  31. OneHot = false,
  32. ValidationSize = 58000,
  33. }).Result;
  34. model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
  35. }
  36. [TestMethod]
  37. public void AlexnetFromSequential()
  38. {
  39. new ModelSaveTest().AlexnetFromSequential();
  40. var model = tf.keras.models.load_model(@"./alexnet_from_sequential");
  41. model.summary();
  42. model.compile(new Adam(0.001f), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });
  43. var num_epochs = 1;
  44. var batch_size = 8;
  45. var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);
  46. model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);
  47. }
  48. [TestMethod]
  49. public void ModelWithSelfDefinedModule()
  50. {
  51. var model = tf.keras.models.load_model(@"Assets/python_func_model");
  52. model.summary();
  53. model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
  54. var data_loader = new MnistModelLoader();
  55. var num_epochs = 1;
  56. var batch_size = 8;
  57. var dataset = data_loader.LoadAsync(new ModelLoadSetting
  58. {
  59. TrainDir = "mnist",
  60. OneHot = false,
  61. ValidationSize = 55000,
  62. }).Result;
  63. model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
  64. }
  65. [TestMethod]
  66. public void LSTMLoad()
  67. {
  68. var inputs = np.random.randn(10, 5, 3);
  69. var outputs = np.random.randn(10, 1);
  70. var model = keras.Sequential();
  71. model.add(keras.Input(shape: (5, 3)));
  72. var lstm = keras.layers.LSTM(32);
  73. model.add(lstm);
  74. model.add(keras.layers.Dense(1, keras.activations.Sigmoid));
  75. model.compile(optimizer: keras.optimizers.Adam(),
  76. loss: keras.losses.MeanSquaredError(),
  77. new[] { "accuracy" });
  78. var result = model.fit(inputs.numpy(), outputs.numpy(), batch_size: 10, epochs: 3, workers: 16, use_multiprocessing: true);
  79. model.save("LSTM_Random");
  80. var model_loaded = keras.models.load_model("LSTM_Random");
  81. model_loaded.summary();
  82. }
  83. [Ignore]
  84. [TestMethod]
  85. public void VGG19()
  86. {
  87. var model = tf.keras.models.load_model(@"D:\development\tf.net\models\VGG19");
  88. model.summary();
  89. var classify_model = keras.Sequential(new System.Collections.Generic.List<ILayer>()
  90. {
  91. model,
  92. keras.layers.Flatten(),
  93. keras.layers.Dense(10),
  94. });
  95. classify_model.summary();
  96. classify_model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
  97. var x = np.random.uniform(0, 1, (8, 512, 512, 3));
  98. var y = np.ones(8);
  99. classify_model.fit(x, y, batch_size: 4);
  100. }
  101. [Ignore]
  102. [TestMethod]
  103. public void TestModelBeforeTF2_5()
  104. {
  105. var a = keras.layers;
  106. var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model;
  107. model.summary();
  108. }
  109. }