You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MultiInputModelTest.cs 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. using Microsoft.VisualStudio.TestPlatform.Utilities;
  2. using Microsoft.VisualStudio.TestTools.UnitTesting;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Text;
  7. using System.Threading.Tasks;
  8. using System.Xml.Linq;
  9. using Tensorflow.Operations;
  10. using static Tensorflow.Binding;
  11. using static Tensorflow.KerasApi;
  12. using Tensorflow.NumPy;
  13. using Microsoft.VisualBasic;
  14. using static HDF.PInvoke.H5T;
  15. using Tensorflow.Keras.UnitTest.Helpers;
  16. using Tensorflow.Keras.Optimizers;
  namespace Tensorflow.Keras.UnitTest
  {
      /// <summary>
      /// Exercises a functional Keras model that takes two separate image inputs,
      /// runs each through its own convolutional branch, concatenates the flattened
      /// branches and classifies the result with a dense head.
      /// </summary>
      [TestClass]
      public class MultiInputModelTest
      {
          /// <summary>
          /// Builds the two-input model, prints its summary, loads MNIST through
          /// <c>MnistModelLoader</c> and fits the model for three epochs.
          /// The test contains no explicit assertions; it passes when the whole
          /// build/compile/fit sequence completes without throwing.
          /// </summary>
          [TestMethod]
          public void SimpleModel()
          {
              // First branch: two 3x3 "same"-padded convolutions, each followed by (2, 2) pooling.
              var firstInput = keras.Input((28, 28, 1));
              var firstConvA = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(firstInput);
              var firstPoolA = keras.layers.MaxPooling2D((2, 2), 2).Apply(firstConvA);
              var firstConvB = keras.layers.Conv2D(32, (3, 3), activation: "relu", padding: "same").Apply(firstPoolA);
              var firstPoolB = keras.layers.MaxPooling2D((2, 2), 2).Apply(firstConvB);
              var firstFlat = keras.layers.Flatten().Apply(firstPoolB);

              // Second branch: same input shape, but (4, 4) pooling after the first
              // convolution and a 1x1 kernel in the second convolution.
              var secondInput = keras.Input((28, 28, 1));
              var secondConvA = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(secondInput);
              var secondPoolA = keras.layers.MaxPooling2D((4, 4), 4).Apply(secondConvA);
              var secondConvB = keras.layers.Conv2D(32, (1, 1), activation: "relu", padding: "same").Apply(secondPoolA);
              var secondPoolB = keras.layers.MaxPooling2D((2, 2), 2).Apply(secondConvB);
              var secondFlat = keras.layers.Flatten().Apply(secondPoolB);

              // Shared head: merge both branches, then 512 -> 128 -> 10 dense units and a softmax.
              var merged = keras.layers.Concatenate().Apply((firstFlat, secondFlat));
              var hidden512 = keras.layers.Dense(512, activation: "relu").Apply(merged);
              var hidden128 = keras.layers.Dense(128, activation: "relu").Apply(hidden512);
              var logits10 = keras.layers.Dense(10, activation: "relu").Apply(hidden128);
              var probabilities = keras.layers.Softmax(-1).Apply(logits10);

              var network = keras.Model((firstInput, secondInput), probabilities);
              network.summary();

              // Load MNIST with integer (non one-hot) labels.
              // NOTE(review): ValidationSize = 59000 presumably leaves only a small
              // training split so the test stays fast - confirm against MnistModelLoader.
              var loader = new MnistModelLoader();
              var settings = new ModelLoadSetting
              {
                  TrainDir = "mnist",
                  OneHot = false,
                  ValidationSize = 59000,
              };
              var loadTask = loader.LoadAsync(settings);
              var mnist = loadTask.Result; // blocks until the dataset has been loaded

              var lossFunction = keras.losses.SparseCategoricalCrossentropy();
              var adam = new Adam(0.001f);
              network.compile(adam, lossFunction, new[] { "accuracy" });

              // Reshape the training images to (samples, 28, 28, 1) and feed the very
              // same array to both model inputs.
              NDArray images = np.reshape(mnist.Train.Data, (mnist.Train.Data.shape[0], 28, 28, 1));
              var modelInputs = new NDArray[] { images, images };

              network.fit(modelInputs, mnist.Train.Labels, batch_size: 8, epochs: 3);
          }
      }
  }