You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, may include dashes ('-'), and can be up to 35 characters long.

EarlystoppingTest.cs 2.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.Keras.UnitTest.Helpers;
  3. using static Tensorflow.Binding;
  4. using Tensorflow;
  5. using Tensorflow.Keras.Optimizers;
  6. using Tensorflow.Keras.Callbacks;
  7. using Tensorflow.Keras.Engine;
  8. using System.Collections.Generic;
  9. using static Tensorflow.KerasApi;
  10. using Tensorflow.Keras;
  11. namespace TensorFlowNET.Keras.UnitTest
  12. {
  13. [TestClass]
  14. public class EarlystoppingTest
  15. {
  16. [TestMethod]
  17. // Because loading the weight variable into the model has not yet been implemented,
  18. // so you'd better not set patience too large, because the weights will equal to the last epoch's weights.
  19. public void Earlystopping()
  20. {
  21. var layers = keras.layers;
  22. var model = keras.Sequential(new List<ILayer>
  23. {
  24. layers.Rescaling(1.0f / 255, input_shape: (32, 32, 3)),
  25. layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu),
  26. layers.MaxPooling2D(),
  27. layers.Flatten(),
  28. layers.Dense(128, activation: keras.activations.Relu),
  29. layers.Dense(10)
  30. });
  31. model.summary();
  32. model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
  33. loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true),
  34. metrics: new[] { "acc" });
  35. var num_epochs = 3;
  36. var batch_size = 8;
  37. var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
  38. x_train = x_train / 255.0f;
  39. // define a CallbackParams first, the parameters you pass al least contain Model and Epochs.
  40. CallbackParams callback_parameters = new CallbackParams
  41. {
  42. Model = model,
  43. Epochs = num_epochs,
  44. };
  45. // define your earlystop
  46. ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy");
  47. // define a callbcaklist, then add the earlystopping to it.
  48. var callbacks = new List<ICallback>();
  49. callbacks.add(earlystop);
  50. model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs,callbacks:callbacks);
  51. }
  52. }
  53. }