You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

TensorTest.cs 3.5 kB

6 years ago
6 years ago
6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using System;
  4. using System.Linq;
  5. using static Tensorflow.Binding;
  6. namespace TensorFlowNET.UnitTest.Basics
  7. {
  8. [TestClass]
  9. public class TensorTest : GraphModeTestBase
  10. {
  /// <summary>
  /// Writes the value 1 at the coordinates (i, i), i = 0..4, of a 5x5 tensor
  /// through <c>tf.sparse_to_dense</c> and checks that each evaluated row is
  /// the matching one-hot vector. The test is currently marked <c>[Ignore]</c>.
  /// </summary>
  [TestMethod, Ignore]
  public void sparse_to_dense()
  {
      // Two 5x1 columns (0..4 each) joined along axis 1 give the coordinate pairs (i, i).
      var rowIds = tf.reshape(tf.range(0, 5), new int[] { 5, 1 });
      var colIds = tf.expand_dims(tf.constant(new[] { 0, 1, 2, 3, 4 }), 1);
      var coordinates = tf.concat(values: new[] { rowIds, colIds }, axis: 1);
      var dense = tf.sparse_to_dense(coordinates, (5, 5), 1);

      using (var session = tf.Session())
      {
          var evaluated = session.run(dense);

          for (var row = 0; row < 5; row++)
          {
              // Expected row: a single 1 in column `row`, zeros everywhere else.
              var expectedRow = Enumerable.Range(0, 5)
                                          .Select(col => col == row ? 1 : 0)
                                          .ToArray();

              Assert.IsTrue(Enumerable.SequenceEqual(expectedRow, evaluated[row].ToArray<int>()));
          }
      }
  }
  /// <summary>
  /// Creates a sparse tensor from the coordinate pairs (0, 0) and (1, 2), the
  /// values {1, 2} and the shape {3, 4}, converts it with
  /// <c>tf.sparse_tensor_to_dense</c>, and checks the three evaluated rows.
  /// The test is currently marked <c>[Ignore]</c>.
  /// </summary>
  [TestMethod, Ignore]
  public void sparse_tensor_to_dense()
  {
      // Arguments are passed positionally, in the same order and with the same
      // element types as before: 2-D long coordinates, int values, long shape.
      long[,] positions =
      {
          { 0L, 0L },
          { 1L, 2L }
      };
      int[] values = { 1, 2 };
      long[] denseShape = { 3L, 4L };

      var sparse = tf.SparseTensor(positions, values, denseShape);
      var dense = tf.sparse_tensor_to_dense(sparse);

      // Row-by-row expectation for the densified tensor.
      int[][] expectedRows =
      {
          new[] { 1, 0, 0, 0 },
          new[] { 0, 0, 2, 0 },
          new[] { 0, 0, 0, 0 }
      };

      using (var session = tf.Session())
      {
          var evaluated = session.run(dense);

          for (var row = 0; row < expectedRows.Length; row++)
          {
              Assert.IsTrue(Enumerable.SequenceEqual(expectedRows[row], evaluated[row].ToArray<int>()));
          }
      }
  }
  /// <summary>
  /// Runs <c>tf.batch_to_space_nd</c> on the values 0..23 reshaped to
  /// (4, 2, 3), using block shape {2, 2} and all-zero crops, and checks the
  /// four slices <c>result[0, 0]</c> .. <c>result[0, 3]</c> of the evaluated output.
  /// </summary>
  [TestMethod]
  public void batch_to_space_nd()
  {
      var input = np.arange(24).reshape((4, 2, 3));
      var blockShape = new[] { 2, 2 };
      int[,] noCrops = { { 0, 0 }, { 0, 0 } };
      var rearranged = tf.batch_to_space_nd(input, blockShape, noCrops);

      // Expected contents of result[0, k] for k = 0..3.
      int[][] expectedSlices =
      {
          new[] { 0, 6, 1, 7, 2, 8 },
          new[] { 12, 18, 13, 19, 14, 20 },
          new[] { 3, 9, 4, 10, 5, 11 },
          new[] { 15, 21, 16, 22, 17, 23 }
      };

      using (var session = tf.Session())
      {
          var evaluated = session.run(rearranged);

          for (var k = 0; k < expectedSlices.Length; k++)
          {
              Assert.IsTrue(Enumerable.SequenceEqual(expectedSlices[k], evaluated[0, k].ToArray<int>()));
          }
      }
  }
  /// <summary>
  /// Applies the mask {true, false, true, false} to the values {0, 1, 2, 3}
  /// with <c>tf.boolean_mask</c>, evaluates the result in a session, and
  /// checks that the evaluated output is {0, 2}.
  /// The test is currently marked <c>[Ignore]</c>.
  /// </summary>
  [TestMethod, Ignore]
  public void boolean_mask()
  {
      var tensor = new[] { 0, 1, 2, 3 };
      var mask = np.array(new[] { true, false, true, false });
      var masked = tf.boolean_mask(tensor, mask);
      using (var sess = tf.Session())
      {
          var result = sess.run(masked);

          // Assert on the value returned by sess.run rather than on `masked`,
          // the un-evaluated operand handed to the session. Previously `result`
          // was computed and then never inspected.
          Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, result.ToArray<int>()));
      }
  }
  75. }
  76. }