You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SequentialModelLoad.cs 2.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.Linq;
  6. using System.Text;
  7. using System.Threading.Tasks;
  8. using Tensorflow.Keras.Engine;
  9. using Tensorflow.Keras.Saving.SavedModel;
  10. using Tensorflow.Keras.Losses;
  11. using Tensorflow.Keras.Metrics;
  12. using Tensorflow;
  13. using Tensorflow.Keras.Optimizers;
  14. using static Tensorflow.KerasApi;
  15. using Tensorflow.NumPy;
  16. using Tensorflow.Keras.UnitTest.Helpers;
  17. using static TensorFlowNET.Keras.UnitTest.SaveModel.SequentialModelSave;
  18. namespace TensorFlowNET.Keras.UnitTest.SaveModel;
  /// <summary>
  /// Tests that load Keras Sequential models from disk with
  /// <c>keras.models.load_model</c>, recompile them, and run a one-epoch training pass.
  /// </summary>
  [TestClass]
  public class SequentialModelLoad
  {
      /// <summary>
      /// Loads the saved model under <c>Assets/simple_model_from_auto_compile</c>,
      /// compiles it with Adam + sparse categorical cross-entropy, verifies two of its
      /// trainable weights against reference <c>.npy</c> dumps stored alongside it,
      /// and fits it for one epoch on MNIST.
      /// </summary>
      [TestMethod]
      public async Task SimpleModelFromAutoCompile()
      {
          var model = keras.models.load_model(@"Assets/simple_model_from_auto_compile");
          model.summary();
          model.compile(new Adam(0.0001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });

          // Check the restored weights: TrainableWeights[2] is compared with kernel1.npy
          // and TrainableWeights[1] with bias0.npy.
          var kernel1 = np.load(@"Assets/simple_model_from_auto_compile/kernel1.npy");
          var bias0 = np.load(@"Assets/simple_model_from_auto_compile/bias0.npy");
          var loaded_kernel1 = model.TrainableWeights[2].numpy();
          var loaded_bias0 = model.TrainableWeights[1].numpy();

          // Zip stops at the shorter sequence, so without a shape check a truncated or
          // mismatched weight array would pass the element-wise comparison vacuously.
          Assert.AreEqual(kernel1.shape, loaded_kernel1.shape);
          Assert.AreEqual(bias0.shape, loaded_bias0.shape);
          Assert.IsTrue(kernel1.Zip(loaded_kernel1).All(x => x.First == x.Second));
          Assert.IsTrue(bias0.Zip(loaded_bias0).All(x => x.First == x.Second));

          var data_loader = new MnistModelLoader();
          var num_epochs = 1;
          var batch_size = 8;
          // Await instead of blocking on .Result: avoids sync-over-async and surfaces the
          // loader's real exception instead of an AggregateException wrapper.
          var dataset = await data_loader.LoadAsync(new ModelLoadSetting
          {
              TrainDir = "mnist",
              OneHot = false,
              // NOTE(review): presumably a large validation split keeps the training
              // set small so one epoch stays quick — confirm.
              ValidationSize = 50000,
          });
          model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
      }

      /// <summary>
      /// Saves an AlexNet-style Sequential model by running
      /// <see cref="SequentialModelSave.AlexnetFromSequential"/>, reloads it from
      /// <c>./alexnet_from_sequential</c>, and fits it for one epoch on a random
      /// dataset of 16 samples shaped (227, 227, 3).
      /// </summary>
      [TestMethod]
      public void AlexnetFromSequential()
      {
          // Produce the on-disk model this test loads; the save test writes it.
          new SequentialModelSave().AlexnetFromSequential();
          var model = keras.models.load_model(@"./alexnet_from_sequential");
          model.summary();
          model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

          var num_epochs = 1;
          var batch_size = 8;
          var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);
          model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);
      }
  }