You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MultiThreadsTest.cs 2.9 kB

2 years ago
2 years ago
2 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Threading.Tasks;
  4. using Tensorflow.Keras.Engine;
  5. using Tensorflow.NumPy;
  6. using static Tensorflow.Binding;
  7. using static Tensorflow.KerasApi;
  namespace Tensorflow.Keras.UnitTest
  {
      /// <summary>
      /// Tests that save Keras model weights to an HDF5 (<c>.h5</c>) file and load them back,
      /// both sequentially and from several freshly built models running in parallel.
      /// </summary>
      /// <remarks>
      /// NOTE(review): none of these tests assert on results; each passes as long as no exception
      /// is thrown. Comparing predictions of the original and reloaded models would make them stronger.
      /// </remarks>
      [TestClass]
      public class MultiThreads
      {
          /// <summary>
          /// Saves the weights of one model, builds a second model, then loads the saved file into
          /// both models.
          /// </summary>
          [TestMethod, Ignore("Failed on MacOS")]
          public void Test1()
          {
              // Arrange
              string savefile = "mymodel.h5";
              try
              {
                  var model1 = BuildModel();
                  model1.save_weights(savefile);
                  var model2 = BuildModel();

                  // Act
                  model1.load_weights(savefile);
                  model2.load_weights(savefile);
              }
              finally
              {
                  TryDeleteFile(savefile);
              }
          }

          /// <summary>
          /// Saves the weights of a model, replaces it with a freshly built model, and loads the
          /// saved file into the new instance.
          /// </summary>
          [TestMethod, Ignore("Failed on MacOS")]
          public void Test2()
          {
              // Arrange
              string savefile = "mymodel2.h5";
              try
              {
                  var model1 = BuildModel();
                  model1.save_weights(savefile);
                  model1 = BuildModel(); // recreate the model before reloading

                  // Act
                  model1.load_weights(savefile);
              }
              finally
              {
                  TryDeleteFile(savefile);
              }
          }

          /// <summary>
          /// Saves weights once, then repeatedly builds a model, loads the weights and runs a
          /// prediction: first twice sequentially as a sanity check, then in parallel.
          /// </summary>
          [TestMethod, Ignore("Failed on MacOS")]
          public void Test3Multithreading()
          {
              // Number of parallel iterations, and also the cap on how many run at once.
              const int concurrentClones = 8;

              // Arrange
              string savefile = "mymodel3.h5";
              try
              {
                  var model = BuildModel();
                  model.save_weights(savefile);

                  // Sanity check without multithreading.
                  for (int i = 0; i < 2; i++)
                  {
                      var clone = BuildModel();
                      clone.load_weights(savefile);
                      clone.predict(np.array(new float[,] { { 0, 0 } }));
                  }

                  // Act: the same build / load / predict cycle, run in parallel.
                  // NOTE(review): BuildModel calls tf.Context.reset_context() on every worker while
                  // other workers may still be using it; confirm whether tf.Context is per-thread in
                  // this TensorFlow.NET version, as this may explain the failure.
                  var parallelOptions = new ParallelOptions { MaxDegreeOfParallelism = concurrentClones };
                  var input = np.array(new float[,] { { 0, 0 } });
                  Parallel.For(0, concurrentClones, parallelOptions, i =>
                  {
                      var clone = BuildModel();
                      clone.load_weights(savefile);
                      clone.predict(input);
                  });
              }
              finally
              {
                  TryDeleteFile(savefile);
              }
          }

          /// <summary>
          /// Calls <c>tf.Context.reset_context()</c>, then builds and compiles a small model:
          /// an input of shape 2 feeding one <c>Dense(1)</c> layer with sigmoid activation.
          /// </summary>
          /// <remarks>
          /// The model, its loss and its optimizer each receive a new GUID as their name,
          /// presumably so repeated builds never reuse a name. TODO confirm.
          /// The model summary is printed on every call.
          /// </remarks>
          /// <returns>The compiled model (MSE loss, Adam optimizer, "accuracy" metric).</returns>
          private static IModel BuildModel()
          {
              tf.Context.reset_context();

              var inputs = keras.Input(shape: 2);
              var denseLayer = keras.layers.Dense(1, activation: keras.activations.Sigmoid);
              var outputs = denseLayer.Apply(inputs);

              var model = tf.keras.Model(inputs, outputs, name: UniqueName());
              model.summary();

              model.compile(loss: keras.losses.MeanSquaredError(name: UniqueName()),
                  optimizer: keras.optimizers.Adam(name: UniqueName()),
                  metrics: new[] { "accuracy" });
              return model;
          }

          /// <summary>Returns a new GUID string, used as a unique Keras object name.</summary>
          private static string UniqueName() => Guid.NewGuid().ToString();

          /// <summary>
          /// Removes a weights file written by a test so runs do not leave <c>.h5</c> files in
          /// the working directory. Cleanup is best-effort: an I/O or permission failure is
          /// ignored so it cannot replace an exception already thrown by the test body.
          /// </summary>
          /// <param name="path">Path of the file to delete; a missing file is not an error.</param>
          private static void TryDeleteFile(string path)
          {
              try
              {
                  System.IO.File.Delete(path);
              }
              catch (System.IO.IOException)
              {
                  // Deliberately ignored: e.g. the file may still be open elsewhere.
              }
              catch (UnauthorizedAccessException)
              {
                  // Deliberately ignored: cleanup must not turn a passing test into a failure.
              }
          }
      }
  }