You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MultiInputModelTest.cs 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using Tensorflow.Keras.Optimizers;
  4. using Tensorflow.NumPy;
  5. using static Tensorflow.Binding;
  6. using static Tensorflow.KerasApi;
  7. namespace Tensorflow.Keras.UnitTest
  8. {
  9. [TestClass]
  10. public class MultiInputModelTest
  11. {
  12. [TestMethod]
  13. public void LeNetModel()
  14. {
  15. var inputs = keras.Input((28, 28, 1));
  16. var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs);
  17. var pool1 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv1);
  18. var conv2 = keras.layers.Conv2D(32, (3, 3), activation: "relu", padding: "same").Apply(pool1);
  19. var pool2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2);
  20. var flat1 = keras.layers.Flatten().Apply(pool2);
  21. var inputs_2 = keras.Input((28, 28, 1));
  22. var conv1_2 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs_2);
  23. var pool1_2 = keras.layers.MaxPooling2D((4, 4), 4).Apply(conv1_2);
  24. var conv2_2 = keras.layers.Conv2D(32, (1, 1), activation: "relu", padding: "same").Apply(pool1_2);
  25. var pool2_2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2_2);
  26. var flat1_2 = keras.layers.Flatten().Apply(pool2_2);
  27. var concat = keras.layers.Concatenate().Apply((flat1, flat1_2));
  28. var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat);
  29. var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1);
  30. var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2);
  31. var output = keras.layers.Softmax(-1).Apply(dense3);
  32. var model = keras.Model((inputs, inputs_2), output);
  33. model.summary();
  34. var data_loader = new MnistModelLoader();
  35. var dataset = data_loader.LoadAsync(new ModelLoadSetting
  36. {
  37. TrainDir = "mnist",
  38. OneHot = false,
  39. ValidationSize = 59900,
  40. }).Result;
  41. var loss = keras.losses.SparseCategoricalCrossentropy();
  42. var optimizer = new Adam(0.001f);
  43. model.compile(optimizer, loss, new string[] { "accuracy" });
  44. NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1));
  45. NDArray x2 = x1;
  46. var x = new NDArray[] { x1, x2 };
  47. model.fit(x, dataset.Train.Labels, batch_size: 8, epochs: 3);
  48. x1 = x1["0:8"];
  49. x2 = x1;
  50. x = new NDArray[] { x1, x2 };
  51. var y = dataset.Train.Labels["0:8"];
  52. (model as Engine.Model).evaluate(x, y);
  53. x1 = np.ones((1, 28, 28, 1), TF_DataType.TF_FLOAT);
  54. x2 = np.zeros((1, 28, 28, 1), TF_DataType.TF_FLOAT);
  55. var pred = model.predict((x1, x2));
  56. Console.WriteLine(pred);
  57. }
  58. [TestMethod]
  59. public void LeNetModelDataset()
  60. {
  61. var inputs = keras.Input((28, 28, 1));
  62. var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs);
  63. var pool1 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv1);
  64. var conv2 = keras.layers.Conv2D(32, (3, 3), activation: "relu", padding: "same").Apply(pool1);
  65. var pool2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2);
  66. var flat1 = keras.layers.Flatten().Apply(pool2);
  67. var inputs_2 = keras.Input((28, 28, 1));
  68. var conv1_2 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs_2);
  69. var pool1_2 = keras.layers.MaxPooling2D((4, 4), 4).Apply(conv1_2);
  70. var conv2_2 = keras.layers.Conv2D(32, (1, 1), activation: "relu", padding: "same").Apply(pool1_2);
  71. var pool2_2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2_2);
  72. var flat1_2 = keras.layers.Flatten().Apply(pool2_2);
  73. var concat = keras.layers.Concatenate().Apply((flat1, flat1_2));
  74. var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat);
  75. var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1);
  76. var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2);
  77. var output = keras.layers.Softmax(-1).Apply(dense3);
  78. var model = keras.Model((inputs, inputs_2), output);
  79. model.summary();
  80. var data_loader = new MnistModelLoader();
  81. var dataset = data_loader.LoadAsync(new ModelLoadSetting
  82. {
  83. TrainDir = "mnist",
  84. OneHot = false,
  85. ValidationSize = 59900,
  86. }).Result;
  87. var loss = keras.losses.SparseCategoricalCrossentropy();
  88. var optimizer = new Adam(0.001f);
  89. model.compile(optimizer, loss, new string[] { "accuracy" });
  90. NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1));
  91. var multiInputDataset = tf.data.Dataset.zip(
  92. tf.data.Dataset.from_tensor_slices(x1),
  93. tf.data.Dataset.from_tensor_slices(x1),
  94. tf.data.Dataset.from_tensor_slices(dataset.Train.Labels)
  95. ).batch(8);
  96. multiInputDataset.FirstInputTensorCount = 2;
  97. model.fit(multiInputDataset, epochs: 3);
  98. x1 = x1["0:8"];
  99. multiInputDataset = tf.data.Dataset.zip(
  100. tf.data.Dataset.from_tensor_slices(x1),
  101. tf.data.Dataset.from_tensor_slices(x1),
  102. tf.data.Dataset.from_tensor_slices(dataset.Train.Labels["0:8"])
  103. ).batch(8);
  104. multiInputDataset.FirstInputTensorCount = 2;
  105. (model as Engine.Model).evaluate(multiInputDataset);
  106. x1 = np.ones((1, 28, 28, 1), TF_DataType.TF_FLOAT);
  107. var x2 = np.zeros((1, 28, 28, 1), TF_DataType.TF_FLOAT);
  108. multiInputDataset = tf.data.Dataset.zip(
  109. tf.data.Dataset.from_tensor_slices(x1),
  110. tf.data.Dataset.from_tensor_slices(x2)
  111. ).batch(8);
  112. multiInputDataset.FirstInputTensorCount = 2;
  113. var pred = model.predict(multiInputDataset);
  114. Console.WriteLine(pred);
  115. }
  116. }
  117. }