You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TensorTest.cs 3.3 kB

6 years ago
6 years ago
6 years ago
6 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using System;
  4. using System.Linq;
  5. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.Basics
  {
      /// <summary>
      /// Tests for TensorFlow.NET tensor-manipulation ops
      /// (sparse_to_dense, sparse_tensor_to_dense, batch_to_space_nd, boolean_mask).
      /// </summary>
      /// <remarks>
      /// Each test builds ops through <c>tf</c>, evaluates them with a
      /// <c>tf.Session()</c>, and compares the fetched values against expected arrays.
      /// NOTE(review): graph-mode setup is presumably provided by
      /// <c>GraphModeTestBase</c> — confirm.
      /// </remarks>
      [TestClass]
      public class TensorTest : GraphModeTestBase
      {
          /// <summary>
          /// Scatters the value 1 at coordinates (i, i) for i in 0..4 into a 5x5
          /// dense tensor and checks that the result is the identity matrix.
          /// </summary>
          [TestMethod, Ignore]
          public void sparse_to_dense()
          {
              var indices = tf.reshape(tf.range(0, 5), new int[] { 5, 1 });
              var labels = tf.expand_dims(tf.constant(new[] { 0, 1, 2, 3, 4 }), 1);
              // Concatenating two 5x1 columns yields the coordinates [[0,0], [1,1], ..., [4,4]].
              var st = tf.concat(values: new[] { indices, labels }, axis: 1);
              var onehot = tf.sparse_to_dense(st, (5, 5), 1);

              // Scope the session so it is disposed even if an assertion fails.
              using (var sess = tf.Session())
              {
                  var result = sess.run(onehot);
                  AssertSequenceEqual(new int[] { 1, 0, 0, 0, 0 }, result[0].ToArray<int>());
                  AssertSequenceEqual(new int[] { 0, 1, 0, 0, 0 }, result[1].ToArray<int>());
                  AssertSequenceEqual(new int[] { 0, 0, 1, 0, 0 }, result[2].ToArray<int>());
                  AssertSequenceEqual(new int[] { 0, 0, 0, 1, 0 }, result[3].ToArray<int>());
                  AssertSequenceEqual(new int[] { 0, 0, 0, 0, 1 }, result[4].ToArray<int>());
              }
          }

          /// <summary>
          /// Builds a sparse tensor holding 1 at (0, 0) and 2 at (1, 2) with dense
          /// shape 3x4, converts it to dense, and expects zeros everywhere else.
          /// </summary>
          [TestMethod, Ignore]
          public void sparse_tensor_to_dense()
          {
              var decoded_list = tf.SparseTensor(new[,]
              {
                  { 0L, 0L },
                  { 1L, 2L }
              },
              new int[] { 1, 2 },
              new[] { 3L, 4L });
              var onehot = tf.sparse_tensor_to_dense(decoded_list);

              using (var sess = tf.Session())
              {
                  var result = sess.run(onehot);
                  AssertSequenceEqual(new int[] { 1, 0, 0, 0 }, result[0].ToArray<int>());
                  AssertSequenceEqual(new int[] { 0, 0, 2, 0 }, result[1].ToArray<int>());
                  AssertSequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray<int>());
              }
          }

          /// <summary>
          /// Rearranges a (4, 2, 3) input holding 0..23 with block shape [2, 2] and
          /// no cropping, then checks rows result[0, 0..3], each of length 6.
          /// </summary>
          [TestMethod]
          public void batch_to_space_nd()
          {
              var inputs = np.arange(24).reshape((4, 2, 3));
              var block_shape = new[] { 2, 2 };
              int[,] crops = { { 0, 0 }, { 0, 0 } };
              var tensor = tf.batch_to_space_nd(inputs, block_shape, crops);

              using (var sess = tf.Session())
              {
                  var result = sess.run(tensor);
                  AssertSequenceEqual(new int[] { 0, 6, 1, 7, 2, 8 }, result[0, 0].ToArray<int>());
                  AssertSequenceEqual(new int[] { 12, 18, 13, 19, 14, 20 }, result[0, 1].ToArray<int>());
                  AssertSequenceEqual(new int[] { 3, 9, 4, 10, 5, 11 }, result[0, 2].ToArray<int>());
                  AssertSequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>());
              }
          }

          /// <summary>
          /// Keeps the elements of [0, 1, 2, 3] whose mask entry is true and
          /// expects [0, 2].
          /// </summary>
          [TestMethod, Ignore]
          public void boolean_mask()
          {
              var tensor = new[] { 0, 1, 2, 3 };
              var mask = np.array(new[] { true, false, true, false });
              var masked = tf.boolean_mask(tensor, mask);

              using (var sess = tf.Session())
              {
                  var result = sess.run(masked);
                  // Assert on the fetched value, not on the symbolic `masked` tensor.
                  AssertSequenceEqual(new int[] { 0, 2 }, result.ToArray<int>());
              }
          }

          /// <summary>
          /// Asserts that <paramref name="actual"/> has exactly the elements of
          /// <paramref name="expected"/> in the same order, and reports both
          /// sequences in the failure message.
          /// </summary>
          /// <param name="expected">The expected element sequence.</param>
          /// <param name="actual">The sequence produced by the op under test.</param>
          private static void AssertSequenceEqual(int[] expected, int[] actual)
          {
              Assert.IsTrue(
                  Enumerable.SequenceEqual(expected, actual),
                  $"Expected [{string.Join(", ", expected)}] but got [{string.Join(", ", actual)}]");
          }
      }
  }