You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ModelLoadTest.cs 3.6 kB

Line numbers 1–113
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System.Linq;
  3. using Tensorflow.Keras.Optimizers;
  4. using Tensorflow.Keras.UnitTest.Helpers;
  5. using Tensorflow.NumPy;
  6. using static Tensorflow.Binding;
  7. using static Tensorflow.KerasApi;
  namespace Tensorflow.Keras.UnitTest.Model;

/// <summary>
/// Tests that load saved Keras models from disk, recompile them and run a short
/// <c>fit</c> pass to check that the loaded model is usable.
/// </summary>
[TestClass]
public class ModelLoadTest
{
    /// <summary>
    /// Loads <c>Assets/simple_model_from_auto_compile</c>, compares two of its trainable
    /// weights with the reference <c>.npy</c> dumps stored next to it, then fits one
    /// epoch on MNIST.
    /// </summary>
    [TestMethod]
    public void SimpleModelFromAutoCompile()
    {
        var model = tf.keras.models.load_model(@"Assets/simple_model_from_auto_compile");
        model.summary();
        model.compile(new Adam(0.0001f), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

        // Check the loaded weights against the reference dumps.
        // NOTE(review): assumes TrainableWeights[2] is kernel1 and TrainableWeights[1]
        // is bias0 (per the file names) — confirm against the saved model's layer order.
        var kernel1 = np.load(@"Assets/simple_model_from_auto_compile/kernel1.npy");
        var bias0 = np.load(@"Assets/simple_model_from_auto_compile/bias0.npy");
        var loaded_kernel1 = model.TrainableWeights[2].numpy();
        var loaded_bias0 = model.TrainableWeights[1].numpy();
        // Zip stops at the shorter sequence, so without a shape check a truncated or
        // mis-shaped weight would only be compared over the common prefix and still pass.
        Assert.AreEqual(kernel1.shape, loaded_kernel1.shape);
        Assert.IsTrue(kernel1.Zip(loaded_kernel1).All(x => x.First == x.Second));
        Assert.AreEqual(bias0.shape, loaded_bias0.shape);
        Assert.IsTrue(bias0.Zip(loaded_bias0).All(x => x.First == x.Second));

        var data_loader = new MnistModelLoader();
        var num_epochs = 1;
        var batch_size = 8;
        // GetAwaiter().GetResult() blocks exactly like .Result, but rethrows the loader's
        // own exception instead of wrapping it in an AggregateException.
        var dataset = data_loader.LoadAsync(new ModelLoadSetting
        {
            TrainDir = "mnist",
            OneHot = false,
            ValidationSize = 58000,
        }).GetAwaiter().GetResult();
        model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
    }

    /// <summary>
    /// Runs <c>ModelSaveTest.AlexnetFromSequential</c> first (presumably it writes
    /// <c>./alexnet_from_sequential</c> — TODO confirm), then loads that directory and
    /// fits one epoch on a random 227x227x3 dataset.
    /// </summary>
    [TestMethod]
    public void AlexnetFromSequential()
    {
        new ModelSaveTest().AlexnetFromSequential();
        var model = tf.keras.models.load_model(@"./alexnet_from_sequential");
        model.summary();
        model.compile(new Adam(0.001f), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });
        var num_epochs = 1;
        var batch_size = 8;
        // NOTE(review): the second argument is presumably the sample count — confirm.
        var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);
        model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);
    }

    /// <summary>
    /// Loads <c>Assets/python_func_model</c>, compiles it with a default Adam optimizer
    /// and fits one epoch on MNIST.
    /// </summary>
    [TestMethod]
    public void ModelWithSelfDefinedModule()
    {
        var model = tf.keras.models.load_model(@"Assets/python_func_model");
        model.summary();
        model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
        var data_loader = new MnistModelLoader();
        var num_epochs = 1;
        var batch_size = 8;
        // See SimpleModelFromAutoCompile: unwraps loader failures instead of AggregateException.
        var dataset = data_loader.LoadAsync(new ModelLoadSetting
        {
            TrainDir = "mnist",
            OneHot = false,
            ValidationSize = 55000,
        }).GetAwaiter().GetResult();
        model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
    }

    /// <summary>
    /// Ignored: depends on a machine-specific VGG19 path. Wraps the loaded model in a
    /// Sequential with Flatten + Dense(10) and fits it on 8 random 512x512x3 inputs.
    /// </summary>
    [Ignore]
    [TestMethod]
    public void VGG19()
    {
        var model = tf.keras.models.load_model(@"D:\development\tf.net\models\VGG19");
        model.summary();
        var classify_model = keras.Sequential(new System.Collections.Generic.List<ILayer>()
        {
            model,
            keras.layers.Flatten(),
            keras.layers.Dense(10),
        });
        classify_model.summary();
        classify_model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
        var x = np.random.uniform(0, 1, (8, 512, 512, 3));
        var y = np.ones(8);
        classify_model.fit(x, y, batch_size: 4);
    }

    /// <summary>
    /// Ignored: depends on a machine-specific path. Loads a SavedModel through
    /// <c>tf.saved_model.load</c> and expects it to be a Keras <c>Model</c>.
    /// </summary>
    [Ignore]
    [TestMethod]
    public void TestModelBeforeTF2_5()
    {
        // The value is unused; the property access is kept on purpose. Presumably it forces
        // the Keras API to initialise before saved_model.load revives Keras objects — TODO confirm.
        _ = keras.layers;
        var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model;
        // `as` yields null when the loaded object is not a Keras Model; fail with a clear
        // message rather than a NullReferenceException from summary().
        Assert.IsNotNull(model, "The saved model did not load as a Tensorflow.Keras.Engine.Model.");
        model.summary();
    }
}