You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SaveTest.cs 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Text;
  7. using System.Threading.Tasks;
  8. using Tensorflow;
  9. using static Tensorflow.Binding;
  10. using static Tensorflow.KerasApi;
  11. using Tensorflow.Keras;
  12. using Tensorflow.Keras.ArgsDefinition;
  13. using Tensorflow.Keras.Engine;
  14. using Tensorflow.Keras.Layers;
  15. using Tensorflow.Keras.Losses;
  16. using Tensorflow.Keras.Metrics;
  17. using Tensorflow.Keras.Optimizers;
  18. namespace TensorFlowNET.Keras.UnitTest;
  19. // class MNISTLoader
  20. // {
  21. // public MNISTLoader()
  22. // {
  23. // var mnist = new MnistModelLoader()
  24. //
  25. // }
  26. // }
  /// <summary>
  /// End-to-end smoke test for Keras model persistence: builds a small functional
  /// model for 28x28x1 inputs, trains it for one epoch on MNIST, saves it with
  /// <c>save_format: "pb"</c>, and checks that the save produced output on disk.
  /// </summary>
  [TestClass]
  public class SaveTest
  {
      /// <summary>
      /// Trains a Flatten -> Dense(100, relu) -> Dense(10) -> Softmax model for a
      /// single epoch (batch size 50) and saves it under the system temp directory.
      /// </summary>
      /// <remarks>
      /// MNIST is read via <c>MnistModelLoader</c> with <c>TrainDir = "mnist"</c>, i.e. a
      /// folder relative to the working directory; presumably the loader downloads the
      /// data there on first use — TODO confirm.
      /// </remarks>
      [TestMethod]
      public void Test()
      {
          // One instance of each API facade instead of constructing a new one per call.
          var kerasApi = new KerasInterface();
          var layers = new LayersApi();

          var inputs = kerasApi.Input((28, 28, 1));
          var x = new Flatten(new FlattenArgs()).Apply(inputs);
          x = new Dense(new DenseArgs() { Units = 100, Activation = tf.nn.relu }).Apply(x);
          x = layers.Dense(units: 10).Apply(x);
          var outputs = layers.Softmax(axis: 1).Apply(x);

          var model = kerasApi.Model(inputs, outputs);
          model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });

          const int numEpochs = 1;
          const int batchSize = 50;

          var dataLoader = new MnistModelLoader();
          // The test method stays synchronous, so model construction and fit run on the
          // same thread as before. GetAwaiter().GetResult() rethrows the original
          // exception instead of wrapping it in an AggregateException the way .Result does.
          var dataset = dataLoader.LoadAsync(new ModelLoadSetting
          {
              TrainDir = "mnist",
              OneHot = false,
              ValidationSize = 0,
          }).GetAwaiter().GetResult();

          // Positional arguments: (x, y, batch_size, epochs).
          model.fit(dataset.Train.Data, dataset.Train.Labels, batchSize, numEpochs);

          // Save under the OS temp directory rather than a hard-coded, machine-specific
          // absolute Windows path, so the test can run on any machine and OS.
          var saveDir = System.IO.Path.Combine(System.IO.Path.GetTempPath(), "tf_test");
          var savePath = System.IO.Path.Combine(saveDir, "tf.net.model");
          System.IO.Directory.CreateDirectory(saveDir);

          // Remove output left by a previous run so the existence check below proves
          // that this run actually wrote something.
          if (System.IO.Directory.Exists(savePath))
          {
              System.IO.Directory.Delete(savePath, recursive: true);
          }
          else if (System.IO.File.Exists(savePath))
          {
              System.IO.File.Delete(savePath);
          }

          model.save(savePath, save_format: "pb");

          // NOTE(review): a SavedModel export is usually a directory; a plain file is
          // also accepted here in case this save format writes a single file.
          Assert.IsTrue(
              System.IO.Directory.Exists(savePath) || System.IO.File.Exists(savePath),
              $"model.save did not produce any output at '{savePath}'.");
      }
  }