You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

TensorTest.cs 3.3 kB

6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using System;
  4. using System.Linq;
  5. using static Tensorflow.Binding;
  6. using Tensorflow;
  namespace TensorFlowNET.UnitTest.Basics
  {
      /// <summary>
      /// Tests for sparse-to-dense conversion, <c>batch_to_space_nd</c> and <c>boolean_mask</c>.
      /// Inherits from <c>GraphModeTestBase</c>; the session-based tests build a graph and evaluate it
      /// with <c>tf.Session()</c>.
      /// </summary>
      [TestClass]
      public class TensorTest : GraphModeTestBase
      {
          /// <summary>
          /// Scatters the value 1 at index pairs [[0,0],[1,1],...,[4,4]] into a 5x5 dense tensor
          /// and expects the 5x5 identity matrix.
          /// </summary>
          /// <remarks>Currently ignored; the reason is not recorded here — TODO(review): confirm.</remarks>
          [TestMethod, Ignore]
          public void sparse_to_dense()
          {
              var indices = tf.reshape(tf.range(0, 5), new int[] { 5, 1 });
              var labels = tf.expand_dims(tf.constant(new[] { 0, 1, 2, 3, 4 }), 1);
              // Concatenating the two 5x1 columns yields (row, col) pairs [[0,0],[1,1],...,[4,4]].
              var st = tf.concat(values: new[] { indices, labels }, axis: 1);
              var onehot = tf.sparse_to_dense(st, (5, 5), 1);

              // Dispose the session deterministically instead of leaking it until finalization.
              using (var sess = tf.Session())
              {
                  var result = sess.run(onehot);
                  CollectionAssert.AreEqual(new int[] { 1, 0, 0, 0, 0 }, result[0].ToArray<int>());
                  CollectionAssert.AreEqual(new int[] { 0, 1, 0, 0, 0 }, result[1].ToArray<int>());
                  CollectionAssert.AreEqual(new int[] { 0, 0, 1, 0, 0 }, result[2].ToArray<int>());
                  CollectionAssert.AreEqual(new int[] { 0, 0, 0, 1, 0 }, result[3].ToArray<int>());
                  CollectionAssert.AreEqual(new int[] { 0, 0, 0, 0, 1 }, result[4].ToArray<int>());
              }
          }

          /// <summary>
          /// Builds a sparse tensor with value 1 at [0,0] and value 2 at [1,2], converts it to a
          /// dense tensor, and checks each row of the 3x4 result.
          /// </summary>
          /// <remarks>Currently ignored; the reason is not recorded here — TODO(review): confirm.</remarks>
          [TestMethod, Ignore]
          public void sparse_tensor_to_dense()
          {
              // Arguments: index pairs, the values stored at those indices, then the dense shape (3, 4).
              var decoded_list = tf.SparseTensor(new[,]
              {
                  { 0L, 0L },
                  { 1L, 2L }
              },
              new int[] { 1, 2 },
              new[] { 3L, 4L });
              var onehot = tf.sparse_tensor_to_dense(decoded_list);

              using (var sess = tf.Session())
              {
                  var result = sess.run(onehot);
                  CollectionAssert.AreEqual(new int[] { 1, 0, 0, 0 }, result[0].ToArray<int>());
                  CollectionAssert.AreEqual(new int[] { 0, 0, 2, 0 }, result[1].ToArray<int>());
                  CollectionAssert.AreEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray<int>());
              }
          }

          /// <summary>
          /// Rearranges a (4, 2, 3) input holding 0..23 with block shape [2, 2] and no cropping.
          /// Checks the four rows found at <c>result[0, 0..3]</c>.
          /// </summary>
          [TestMethod]
          public void batch_to_space_nd()
          {
              var inputs = np.arange(24).reshape((4, 2, 3));
              var block_shape = new[] { 2, 2 };
              int[,] crops = { { 0, 0 }, { 0, 0 } };
              var tensor = tf.batch_to_space_nd(inputs, block_shape, crops);

              using (var sess = tf.Session())
              {
                  var result = sess.run(tensor);
                  CollectionAssert.AreEqual(new int[] { 0, 6, 1, 7, 2, 8 }, result[0, 0].ToArray<int>());
                  CollectionAssert.AreEqual(new int[] { 12, 18, 13, 19, 14, 20 }, result[0, 1].ToArray<int>());
                  CollectionAssert.AreEqual(new int[] { 3, 9, 4, 10, 5, 11 }, result[0, 2].ToArray<int>());
                  CollectionAssert.AreEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>());
              }
          }

          /// <summary>
          /// Applies the mask [true, false, true, false] to [0, 1, 2, 3] and expects [0, 2].
          /// </summary>
          [TestMethod]
          public void boolean_mask()
          {
              // The assertion reads the result directly, without a session, so eager mode is switched on
              // if it is off. NOTE(review): this changes global execution mode; presumably
              // GraphModeTestBase resets it per test — confirm.
              if (!tf.executing_eagerly())
              {
                  tf.enable_eager_execution();
              }

              var tensor = new[] { 0, 1, 2, 3 };
              var mask = np.array(new[] { true, false, true, false });
              var masked = tf.boolean_mask(tensor, mask);
              CollectionAssert.AreEqual(new int[] { 0, 2 }, masked.ToArray<int>());
          }
      }
  }