You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

EarlystoppingTest.cs 2.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.Keras.Callbacks;
using Tensorflow.Keras.Engine;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;
  7. namespace Tensorflow.Keras.UnitTest.Callbacks
  8. {
  9. [TestClass]
  10. public class EarlystoppingTest
  11. {
  12. [TestMethod]
  13. // Because loading the weight variable into the model has not yet been implemented,
  14. // so you'd better not set patience too large, because the weights will equal to the last epoch's weights.
  15. public void Earlystopping()
  16. {
  17. var layers = keras.layers;
  18. var model = keras.Sequential(new List<ILayer>
  19. {
  20. layers.Rescaling(1.0f / 255, input_shape: (28, 28, 1)),
  21. layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu),
  22. layers.MaxPooling2D(),
  23. layers.Flatten(),
  24. layers.Dense(128, activation: keras.activations.Relu),
  25. layers.Dense(10)
  26. });
  27. model.summary();
  28. model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
  29. loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true),
  30. metrics: new[] { "acc" });
  31. var num_epochs = 3;
  32. var batch_size = 8;
  33. var data_loader = new MnistModelLoader();
  34. var dataset = data_loader.LoadAsync(new ModelLoadSetting
  35. {
  36. TrainDir = "mnist",
  37. OneHot = false,
  38. ValidationSize = 59900,
  39. }).Result;
  40. NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1));
  41. NDArray x2 = x1;
  42. var x = new NDArray[] { x1, x2 };
  43. // define a CallbackParams first, the parameters you pass al least contain Model and Epochs.
  44. CallbackParams callback_parameters = new CallbackParams
  45. {
  46. Model = model,
  47. Epochs = num_epochs,
  48. };
  49. // define your earlystop
  50. ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy");
  51. // define a callbcaklist, then add the earlystopping to it.
  52. var callbacks = new List<ICallback>{ earlystop};
  53. model.fit(x, dataset.Train.Labels, batch_size, num_epochs, callbacks: callbacks);
  54. }
  55. }
  56. }