You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TensorTest.cs 3.6 kB

6 years ago
6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. using FluentAssertions;
  2. using Microsoft.VisualStudio.TestTools.UnitTesting;
  3. using Tensorflow.NumPy;
  4. using System;
  5. using System.Linq;
  6. using System.Runtime.InteropServices;
  7. using Tensorflow;
  8. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Graph-mode tests (evaluated through <c>tf.Session</c>) for sparse-to-dense
      /// conversion, <c>batch_to_space_nd</c> and <c>boolean_mask</c>.
      /// </summary>
      /// <remarks>
      /// NOTE(review): the whole class is marked [Ignore]; the reason is not recorded here.
      /// </remarks>
      [TestClass, Ignore]
      public class TensorTest
      {
          /// <summary>
          /// Asserts that <paramref name="actual"/> contains the same elements, in the same
          /// order, as <paramref name="expected"/>.
          /// </summary>
          /// <param name="expected">The expected row values.</param>
          /// <param name="actual">The values read back from an evaluated tensor.</param>
          private static void AssertSequenceEqual(int[] expected, int[] actual)
          {
              // A bare Assert.IsTrue(SequenceEqual(...)) only reports "expected true";
              // include both sequences so a failing row can be diagnosed from the log.
              Assert.IsTrue(
                  expected.SequenceEqual(actual),
                  $"Expected [{string.Join(", ", expected)}] but was [{string.Join(", ", actual)}].");
          }

          /// <summary>
          /// Builds the index pairs (i, i) for i in 0..4 and scatters the value 1 into a
          /// 5x5 dense tensor. The result is expected to be the 5x5 identity matrix.
          /// </summary>
          [TestMethod]
          public void sparse_to_dense()
          {
              var indices = tf.reshape(tf.range(0, 5), new int[] { 5, 1 });
              var labels = tf.expand_dims(tf.constant(new[] { 0, 1, 2, 3, 4 }), 1);
              // Concatenating the two single-column tensors along axis 1 yields (row, column) pairs.
              var st = tf.concat(values: new[] { indices, labels }, axis: 1);
              var onehot = tf.sparse_to_dense(st, (5, 5), 1);

              using (var sess = tf.Session())
              {
                  var result = sess.run(onehot);
                  AssertSequenceEqual(new int[] { 1, 0, 0, 0, 0 }, result[0].ToArray<int>());
                  AssertSequenceEqual(new int[] { 0, 1, 0, 0, 0 }, result[1].ToArray<int>());
                  AssertSequenceEqual(new int[] { 0, 0, 1, 0, 0 }, result[2].ToArray<int>());
                  AssertSequenceEqual(new int[] { 0, 0, 0, 1, 0 }, result[3].ToArray<int>());
                  AssertSequenceEqual(new int[] { 0, 0, 0, 0, 1 }, result[4].ToArray<int>());
              }
          }

          /// <summary>
          /// Densifies a sparse tensor with the entries (0, 0) = 1 and (1, 2) = 2 and a
          /// dense shape of 3x4. All other cells are expected to be 0.
          /// </summary>
          [TestMethod]
          public void sparse_tensor_to_dense()
          {
              var decoded_list = tf.SparseTensor(new[,]
              {
                  { 0L, 0L },
                  { 1L, 2L }
              },
              new int[] { 1, 2 },
              new[] { 3L, 4L });
              var onehot = tf.sparse_tensor_to_dense(decoded_list);

              using (var sess = tf.Session())
              {
                  var result = sess.run(onehot);
                  AssertSequenceEqual(new int[] { 1, 0, 0, 0 }, result[0].ToArray<int>());
                  AssertSequenceEqual(new int[] { 0, 0, 2, 0 }, result[1].ToArray<int>());
                  AssertSequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray<int>());
              }
          }

          /// <summary>
          /// Runs <c>batch_to_space_nd</c> on <c>arange(24)</c> reshaped to 4x2x3, with a
          /// 2x2 block shape and zero crops. It checks rows 0-3 of the first output batch
          /// element, whose values interleave elements taken from the input batches.
          /// </summary>
          [TestMethod]
          public void batch_to_space_nd()
          {
              var inputs = np.arange(24).reshape((4, 2, 3));
              var block_shape = new[] { 2, 2 };
              int[,] crops = { { 0, 0 }, { 0, 0 } };
              var tensor = tf.batch_to_space_nd(inputs, block_shape, crops);

              using (var sess = tf.Session())
              {
                  var result = sess.run(tensor);
                  AssertSequenceEqual(new int[] { 0, 6, 1, 7, 2, 8 }, result[0, 0].ToArray<int>());
                  AssertSequenceEqual(new int[] { 12, 18, 13, 19, 14, 20 }, result[0, 1].ToArray<int>());
                  AssertSequenceEqual(new int[] { 3, 9, 4, 10, 5, 11 }, result[0, 2].ToArray<int>());
                  AssertSequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>());
              }
          }

          /// <summary>
          /// Keeps the elements of [0, 1, 2, 3] whose mask entry is true.
          /// The expected result is [0, 2].
          /// </summary>
          [TestMethod, Ignore]
          public void boolean_mask()
          {
              var tensor = new[] { 0, 1, 2, 3 };
              var mask = np.array(new[] { true, false, true, false });
              var masked = tf.boolean_mask(tensor, mask);

              using (var sess = tf.Session())
              {
                  var result = sess.run(masked);
                  // Assert on the value returned by sess.run (as the other tests do),
                  // not on the unevaluated `masked` tensor.
                  AssertSequenceEqual(new int[] { 0, 2 }, result.ToArray<int>());
              }
          }
      }
  }