You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SequentialModelLoad.cs 2.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Linq;
  4. using Tensorflow;
  5. using Tensorflow.Keras.Optimizers;
  6. using Tensorflow.Keras.UnitTest.Helpers;
  7. using Tensorflow.NumPy;
  8. using static Tensorflow.Binding;
  9. namespace TensorFlowNET.Keras.UnitTest.SaveModel;
  [TestClass]
  public class SequentialModelLoad
  {
      // Saved-model directory used by SimpleModelFromAutoCompile; the reference weight
      // dumps (kernel1.npy, bias0.npy) are read from this same directory.
      private const string SimpleModelDir = @"Assets/simple_model_from_auto_compile";

      /// <summary>
      /// Loads the model saved under <see cref="SimpleModelDir"/>, checks two of its
      /// trainable weights against the reference .npy dumps stored alongside it, then
      /// compiles it and trains it for one epoch on the MNIST training split.
      /// </summary>
      [TestMethod]
      public void SimpleModelFromAutoCompile()
      {
          var model = tf.keras.models.load_model(SimpleModelDir);
          model.summary();
          model.compile(new Adam(0.0001f), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

          // Check the restored weights: TrainableWeights[2] is compared with the kernel1
          // dump and TrainableWeights[1] with the bias0 dump.
          var kernel1 = np.load($"{SimpleModelDir}/kernel1.npy");
          var bias0 = np.load($"{SimpleModelDir}/bias0.npy");
          AssertWeightsEqual(kernel1, model.TrainableWeights[2].numpy(), "kernel1");
          AssertWeightsEqual(bias0, model.TrainableWeights[1].numpy(), "bias0");

          var data_loader = new MnistModelLoader();
          var num_epochs = 1;
          var batch_size = 8;
          // GetAwaiter().GetResult() instead of .Result: a loader failure then surfaces as
          // the original exception rather than wrapped in an AggregateException.
          // NOTE(review): ValidationSize presumably moves that many samples out of the
          // training split to keep this run short — TODO confirm against the loader.
          var dataset = data_loader.LoadAsync(new ModelLoadSetting
          {
              TrainDir = "mnist",
              OneHot = false,
              ValidationSize = 58000,
          }).GetAwaiter().GetResult();
          model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
      }

      /// <summary>
      /// Runs <c>SequentialModelSave.AlexnetFromSequential</c> first, reloads the model
      /// from ./alexnet_from_sequential, compiles it with a logits-based loss and trains
      /// it for one epoch on a <c>RandomDataSet</c> of 227x227x3 inputs.
      /// </summary>
      [TestMethod]
      public void AlexnetFromSequential()
      {
          // NOTE(review): presumably the save test is what writes ./alexnet_from_sequential,
          // which is why it is invoked here — verify in SequentialModelSave.
          new SequentialModelSave().AlexnetFromSequential();
          var model = tf.keras.models.load_model(@"./alexnet_from_sequential");
          model.summary();
          model.compile(new Adam(0.001f), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

          var num_epochs = 1;
          var batch_size = 8;
          // NOTE(review): the second argument (16) is presumably the sample count — TODO confirm.
          var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);
          model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);
      }

      /// <summary>
      /// Loads the model saved under Assets/python_func_model, compiles it with a
      /// default-constructed Adam optimizer and trains it for one epoch on the MNIST
      /// training split.
      /// </summary>
      [TestMethod]
      public void Temp()
      {
          var model = tf.keras.models.load_model(@"Assets/python_func_model");
          model.summary();
          model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

          var data_loader = new MnistModelLoader();
          var num_epochs = 1;
          var batch_size = 8;
          // GetAwaiter().GetResult() instead of .Result so loader failures are not wrapped
          // in an AggregateException.
          var dataset = data_loader.LoadAsync(new ModelLoadSetting
          {
              TrainDir = "mnist",
              OneHot = false,
              ValidationSize = 55000,
          }).GetAwaiter().GetResult();
          model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
      }

      /// <summary>
      /// Asserts that a restored weight matches its reference dump element for element.
      /// </summary>
      /// <param name="expected">Reference values loaded from disk.</param>
      /// <param name="actual">Values read back from the loaded model.</param>
      /// <param name="name">Weight name used in failure messages.</param>
      private static void AssertWeightsEqual(NDArray expected, NDArray actual, string name)
      {
          // Enumerable.Zip stops at the shorter sequence, so without this guard a weight
          // with a different number of entries than the reference could still pass.
          Assert.AreEqual(expected.size, actual.size, $"Element count mismatch for weight '{name}'.");
          Assert.IsTrue(expected.Zip(actual).All(x => x.First == x.Second), $"Values differ for weight '{name}'.");
      }
  }