You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SequentialModelLoad.cs 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Linq;
  4. using Tensorflow;
  5. using Tensorflow.Keras.Engine;
  6. using Tensorflow.Keras.Optimizers;
  7. using Tensorflow.Keras.UnitTest.Helpers;
  8. using Tensorflow.NumPy;
  9. using static Tensorflow.Binding;
  10. using static Tensorflow.KerasApi;
  11. namespace TensorFlowNET.Keras.UnitTest.SaveModel;
  12. [TestClass]
  13. public class SequentialModelLoad
  14. {
  15. [TestMethod]
  16. public void SimpleModelFromAutoCompile()
  17. {
  18. var model = tf.keras.models.load_model(@"Assets/simple_model_from_auto_compile");
  19. model.summary();
  20. model.compile(new Adam(0.0001f), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
  21. // check the weights
  22. var kernel1 = np.load(@"Assets/simple_model_from_auto_compile/kernel1.npy");
  23. var bias0 = np.load(@"Assets/simple_model_from_auto_compile/bias0.npy");
  24. Assert.IsTrue(kernel1.Zip(model.TrainableWeights[2].numpy()).All(x => x.First == x.Second));
  25. Assert.IsTrue(bias0.Zip(model.TrainableWeights[1].numpy()).All(x => x.First == x.Second));
  26. var data_loader = new MnistModelLoader();
  27. var num_epochs = 1;
  28. var batch_size = 8;
  29. var dataset = data_loader.LoadAsync(new ModelLoadSetting
  30. {
  31. TrainDir = "mnist",
  32. OneHot = false,
  33. ValidationSize = 58000,
  34. }).Result;
  35. model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
  36. }
  37. [TestMethod]
  38. public void AlexnetFromSequential()
  39. {
  40. new SequentialModelSave().AlexnetFromSequential();
  41. var model = tf.keras.models.load_model(@"./alexnet_from_sequential");
  42. model.summary();
  43. model.compile(new Adam(0.001f), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });
  44. var num_epochs = 1;
  45. var batch_size = 8;
  46. var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);
  47. model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);
  48. }
  49. [TestMethod]
  50. public void ModelWithSelfDefinedModule()
  51. {
  52. var model = tf.keras.models.load_model(@"Assets/python_func_model");
  53. model.summary();
  54. model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
  55. var data_loader = new MnistModelLoader();
  56. var num_epochs = 1;
  57. var batch_size = 8;
  58. var dataset = data_loader.LoadAsync(new ModelLoadSetting
  59. {
  60. TrainDir = "mnist",
  61. OneHot = false,
  62. ValidationSize = 55000,
  63. }).Result;
  64. model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
  65. }
  66. [Ignore]
  67. [TestMethod]
  68. public void VGG19()
  69. {
  70. var model = tf.keras.models.load_model(@"D:\development\tf.net\models\VGG19");
  71. model.summary();
  72. var classify_model = keras.Sequential(new System.Collections.Generic.List<Tensorflow.Keras.ILayer>()
  73. {
  74. model,
  75. keras.layers.Flatten(),
  76. keras.layers.Dense(10),
  77. });
  78. classify_model.summary();
  79. classify_model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
  80. var x = np.random.uniform(0, 1, (8, 512, 512, 3));
  81. var y = np.ones((8));
  82. classify_model.fit(x, y, batch_size: 4);
  83. }
  84. [Ignore]
  85. [TestMethod]
  86. public void TestModelBeforeTF2_5()
  87. {
  88. var a = keras.layers;
  89. var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Model;
  90. model.summary();
  91. }
  92. }