You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SequentialModelLoad.cs 2.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.Linq;
  6. using System.Text;
  7. using System.Threading.Tasks;
  8. using Tensorflow.Keras.Engine;
  9. using Tensorflow.Keras.Saving.SavedModel;
  10. using Tensorflow.Keras.Losses;
  11. using Tensorflow.Keras.Metrics;
  12. using Tensorflow;
  13. using Tensorflow.Keras.Optimizers;
  14. using static Tensorflow.KerasApi;
  15. using Tensorflow.NumPy;
  16. using static TensorFlowNET.Keras.UnitTest.SaveModel.SequentialModelSave;
  17. namespace TensorFlowNET.Keras.UnitTest.SaveModel;
  /// <summary>
  /// Tests that load previously saved Keras models, recompile them and run a short training pass.
  /// </summary>
  [TestClass]
  public class SequentialModelLoad
  {
      /// <summary>
      /// Directory holding the saved model used by <see cref="SimpleModelFromAutoCompile"/>,
      /// together with the reference <c>.npy</c> weight dumps it is checked against.
      /// </summary>
      private const string SimpleModelDir = "Assets/simple_model_from_auto_compile";

      /// <summary>
      /// Directory the AlexNet test loads its model from after running the corresponding save test.
      /// </summary>
      private const string AlexnetDir = "./alexnet_from_sequential";

      /// <summary>
      /// Loads the saved model in <c>Assets/simple_model_from_auto_compile</c> and recompiles it.
      /// It then compares two of its trainable weight tensors element-wise against the reference
      /// <c>kernel1.npy</c> / <c>bias0.npy</c> dumps and fits one epoch on MNIST data.
      /// </summary>
      [TestMethod]
      public void SimpleModelFromAutoCompile()
      {
          var model = keras.models.load_model(SimpleModelDir);
          model.summary();
          model.compile(new Adam(0.0001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });

          // Check the restored weights against the reference dumps stored next to the model:
          // TrainableWeights[2] is compared with kernel1.npy and TrainableWeights[1] with bias0.npy.
          // Exact equality is used because the values are expected to round-trip unchanged.
          var kernel1 = np.load($"{SimpleModelDir}/kernel1.npy");
          var bias0 = np.load($"{SimpleModelDir}/bias0.npy");
          Assert.IsTrue(kernel1.Zip(model.TrainableWeights[2].numpy()).All(x => x.First == x.Second),
              "Loaded TrainableWeights[2] does not match kernel1.npy.");
          Assert.IsTrue(bias0.Zip(model.TrainableWeights[1].numpy()).All(x => x.First == x.Second),
              "Loaded TrainableWeights[1] does not match bias0.npy.");

          var data_loader = new MnistModelLoader();
          var num_epochs = 1;
          var batch_size = 8;
          // GetAwaiter().GetResult() rethrows the loader's original exception on failure, whereas
          // Task.Result wraps it in an AggregateException and hides the real cause in the test report.
          // NOTE(review): ValidationSize = 50000 presumably moves most MNIST training images into the
          // validation split to keep this epoch short; confirm against MnistModelLoader.
          var dataset = data_loader.LoadAsync(new ModelLoadSetting
          {
              TrainDir = "mnist",
              OneHot = false,
              ValidationSize = 50000,
          }).GetAwaiter().GetResult();
          model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
      }

      /// <summary>
      /// Runs the <c>SequentialModelSave.AlexnetFromSequential</c> test first. That test is expected
      /// to write the model that is then loaded from <c>./alexnet_from_sequential</c>. The loaded model
      /// is recompiled and fitted for one epoch on a <c>RandomDataSet</c> with input shape (227, 227, 3).
      /// </summary>
      [TestMethod]
      public void AlexnetFromSequential()
      {
          // Produce the saved model on disk so this test does not depend on test execution order.
          new SequentialModelSave().AlexnetFromSequential();

          var model = keras.models.load_model(AlexnetDir);
          model.summary();
          model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

          var num_epochs = 1;
          var batch_size = 8;
          // NOTE(review): the second argument (16) is presumably the number of random samples; confirm.
          var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);
          model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);
      }
  }