You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

EarlystoppingTest.cs 2.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System.Collections.Generic;
  3. using Tensorflow.Keras.Callbacks;
  4. using Tensorflow.Keras.Engine;
  5. using static Tensorflow.KerasApi;
  6. namespace Tensorflow.Keras.UnitTest.Callbacks
  7. {
  8. [TestClass]
  9. public class EarlystoppingTest
  10. {
  11. [TestMethod]
  12. // Because loading the weight variable into the model has not yet been implemented,
  13. // so you'd better not set patience too large, because the weights will equal to the last epoch's weights.
  14. public void Earlystopping()
  15. {
  16. var layers = keras.layers;
  17. var model = keras.Sequential(new List<ILayer>
  18. {
  19. layers.Rescaling(1.0f / 255, input_shape: (32, 32, 3)),
  20. layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu),
  21. layers.MaxPooling2D(),
  22. layers.Flatten(),
  23. layers.Dense(128, activation: keras.activations.Relu),
  24. layers.Dense(10)
  25. });
  26. model.summary();
  27. model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
  28. loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true),
  29. metrics: new[] { "acc" });
  30. var num_epochs = 3;
  31. var batch_size = 8;
  32. var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
  33. x_train = x_train / 255.0f;
  34. // define a CallbackParams first, the parameters you pass al least contain Model and Epochs.
  35. CallbackParams callback_parameters = new CallbackParams
  36. {
  37. Model = model,
  38. Epochs = num_epochs,
  39. };
  40. // define your earlystop
  41. ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy");
  42. // define a callbcaklist, then add the earlystopping to it.
  43. var callbacks = new List<ICallback>();
  44. callbacks.add(earlystop);
  45. model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs, callbacks: callbacks);
  46. }
  47. }
  48. }