You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ModelBuildTest.cs 2.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using static Tensorflow.Binding;
  4. using static Tensorflow.KerasApi;
  namespace Tensorflow.Keras.UnitTest.Model
  {
      /// <summary>
      /// Tests that Keras models can be built and compiled through the TensorFlow.NET Keras API.
      /// </summary>
      [TestClass]
      public class ModelBuildTest
      {
          /// <summary>
          /// Smoke test: builds and compiles a functional model made of a single <c>Dense(64)</c>
          /// layer. It covers 1-D and 2-D inputs, each with an unspecified and a fixed batch size.
          /// The test contains no assertions; it passes as long as no exception is thrown.
          /// </summary>
          [TestMethod]
          public void DenseBuild()
          {
              // Two-dimensional input, unspecified batch size.
              BuildAndCompileDenseModel((17, 60));

              // One-dimensional input, unspecified batch size.
              // Note: C# has no 1-tuple literal, so "(60)" would just be the int 60; write it plainly.
              BuildAndCompileDenseModel(60);

              // Two-dimensional input, fixed batch size of 8.
              BuildAndCompileDenseModel((17, 60), batchSize: 8);

              // One-dimensional input, fixed batch size of 8.
              BuildAndCompileDenseModel(60, batchSize: 8);
          }

          /// <summary>
          /// Creates a Keras <c>Input</c> of the given shape and applies a fresh <c>Dense(64)</c> layer to it.
          /// It wraps input and output in a functional model and compiles that model with the Adam
          /// optimizer and categorical cross-entropy loss.
          /// </summary>
          /// <param name="shape">
          /// Input shape passed as the first argument of <c>tf.keras.layers.Input</c>. The batch size is
          /// supplied separately via <paramref name="batchSize"/>.
          /// </param>
          /// <param name="batchSize">
          /// Batch size to pass to <c>Input</c>, or <c>null</c> to call <c>Input</c> with the shape only.
          /// </param>
          private static void BuildAndCompileDenseModel(Shape shape, int? batchSize = null)
          {
              // Call the one-argument overload when no batch size is requested, so the
              // "unspecified batch size" case keeps relying on Input's own default.
              var input = batchSize.HasValue
                  ? tf.keras.layers.Input(shape, batchSize.Value)
                  : tf.keras.layers.Input(shape);

              var dense = tf.keras.layers.Dense(64);
              var output = dense.Apply(input);
              var model = tf.keras.Model(input, output);
              model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy());
          }

          /// <summary>
          /// Smoke test: checks that a <c>Sequential</c> model can contain another <c>Sequential</c>
          /// (<c>Flatten</c> followed by <c>Dense(5)</c>) after an <c>InputLayer</c> of shape (3, 3).
          /// It also checks that the outer model can be compiled and used to predict on a (1, 3, 3)
          /// tensor of ones. The prediction is only printed; nothing about it is asserted.
          /// </summary>
          [TestMethod]
          public void NestedSequential()
          {
              var block1 = keras.Sequential(new[]
              {
                  keras.layers.InputLayer((3, 3)),
                  keras.Sequential(new[]
                  {
                      keras.layers.Flatten(),
                      keras.layers.Dense(5)
                  })
              });
              block1.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy());

              var x = tf.ones((1, 3, 3));
              var y = block1.predict(x);
              Console.WriteLine(y);
          }
      }
  }