You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TensorTest.cs 6.3 kB

6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Runtime.InteropServices;
  7. using System.Text;
  8. using Tensorflow;
  9. namespace TensorFlowNET.UnitTest
  10. {
  [TestClass]
  public class TensorTest : CApiTest
  {
      /// <summary>
      /// Port from c_api_test.cc
      /// `TEST(CAPI, AllocateTensor)`
      /// </summary>
      /// <remarks>
      /// NOTE(review): the whole body is commented out, so this test currently
      /// passes without exercising anything. Re-enable it once the
      /// TF_AllocateTensor binding can be called as sketched below.
      /// </remarks>
      [TestMethod]
      public void AllocateTensor()
      {
          /*ulong num_bytes = 6 * sizeof(float);
          long[] dims = { 2, 3 };
          Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes);
          EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype);
          EXPECT_EQ(2, t.NDims);
          Assert.IsTrue(Enumerable.SequenceEqual(dims, t.shape));
          EXPECT_EQ(num_bytes, t.bytesize);
          t.Dispose();*/
      }

      /// <summary>
      /// Port from c_api_test.cc
      /// `TEST(CAPI, MaybeMove)`
      /// Builds a tensor from an NDArray and checks that MaybeMove() yields a
      /// null handle for it.
      /// </summary>
      [TestMethod]
      public void MaybeMove()
      {
          NDArray nd = np.array(2, 3);
          Tensor t = new Tensor(nd);
          try
          {
              Tensor o = t.MaybeMove();
              ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own.
          }
          finally
          {
              // Release the tensor even when the assertion above throws.
              t.Dispose();
          }
      }

      /// <summary>
      /// Port from c_api_test.cc
      /// `TEST(CAPI, Tensor)`
      /// Wraps a 2x3 float NDArray in a Tensor and verifies the tensor's dtype,
      /// rank, shape, byte size and element values.
      /// </summary>
      [TestMethod]
      public void Tensor()
      {
          var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);
          var tensor = new Tensor(nd);
          try
          {
              var array = tensor.Data<float>();

              EXPECT_EQ(TF_DataType.TF_FLOAT, tensor.dtype);
              EXPECT_EQ(nd.ndim, tensor.rank);
              EXPECT_EQ(nd.shape[0], (int)tensor.shape[0]);
              EXPECT_EQ(nd.shape[1], (int)tensor.shape[1]);
              EXPECT_EQ((ulong)nd.size * sizeof(float), tensor.bytesize);

              // Compare the values read back from the tensor itself; checking the
              // source NDArray would only re-verify the test's own input.
              Assert.IsTrue(Enumerable.SequenceEqual(new float[] { 1, 2, 3, 4, 5, 6 }, array));
          }
          finally
          {
              tensor.Dispose();
          }
      }

      /// <summary>
      /// Port from tensorflow\c\c_api_test.cc
      /// `TEST(CAPI, SetShape)`
      /// Walks a placeholder output through a series of shape updates and checks
      /// the rank/dimensions and status code reported after each step, then
      /// checks the rank and shape query for a scalar constant.
      /// </summary>
      [TestMethod]
      public void SetShape()
      {
          var s = new Status();
          try
          {
              var graph = new Graph();
              var feed = c_test_util.Placeholder(graph, s);
              var feed_out_0 = new TF_Output(feed, 0);

              // Fetch the shape, it should be completely unknown.
              int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
              AssertStatusCode(TF_Code.TF_OK, s);
              EXPECT_EQ(-1, num_dims);

              // Set the shape to be unknown, expect no change.
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s);
              AssertStatusCode(TF_Code.TF_OK, s);
              num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
              AssertStatusCode(TF_Code.TF_OK, s);
              EXPECT_EQ(-1, num_dims);

              // Set the shape to be 2 x Unknown
              long[] dims = { 2, -1 };
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
              AssertStatusCode(TF_Code.TF_OK, s);
              num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
              AssertStatusCode(TF_Code.TF_OK, s);
              EXPECT_EQ(2, num_dims);

              // Get the dimension vector appropriately.
              var returned_dims = new long[dims.Length];
              c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
              AssertStatusCode(TF_Code.TF_OK, s);
              Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));

              // Set to a new valid shape: [2, 3]
              dims[1] = 3;
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
              AssertStatusCode(TF_Code.TF_OK, s);

              // Fetch and see that the new value is returned.
              c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
              AssertStatusCode(TF_Code.TF_OK, s);
              Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));

              // Try to set 'unknown' with unknown rank on the shape and see that
              // it doesn't change.
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s);
              AssertStatusCode(TF_Code.TF_OK, s);
              c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
              AssertStatusCode(TF_Code.TF_OK, s);
              EXPECT_EQ(2, num_dims);
              EXPECT_EQ(2, (int)returned_dims[0]);
              EXPECT_EQ(3, (int)returned_dims[1]);

              // Try to set 'unknown' with same rank on the shape and see that
              // it doesn't change.
              dims[0] = -1;
              dims[1] = -1;
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
              AssertStatusCode(TF_Code.TF_OK, s);
              c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
              AssertStatusCode(TF_Code.TF_OK, s);
              EXPECT_EQ(2, num_dims);
              EXPECT_EQ(2, (int)returned_dims[0]);
              EXPECT_EQ(3, (int)returned_dims[1]);

              // Try to fetch a shape with the wrong num_dims
              c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s);
              AssertStatusCode(TF_Code.TF_INVALID_ARGUMENT, s);

              // Try to set an invalid shape (cannot change 2x3 to a 2x5).
              dims[1] = 5;
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
              AssertStatusCode(TF_Code.TF_INVALID_ARGUMENT, s);

              // Test for a scalar.
              var three = c_test_util.ScalarConst(3, graph, s);
              AssertStatusCode(TF_Code.TF_OK, s);
              var three_out_0 = new TF_Output(three, 0);
              num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s);
              AssertStatusCode(TF_Code.TF_OK, s);
              EXPECT_EQ(0, num_dims);

              // Query the scalar's own output (not the placeholder's) with its
              // rank of 0; no dimension buffer is needed for zero dimensions.
              c_api.TF_GraphGetTensorShape(graph, three_out_0, null, num_dims, s);
              AssertStatusCode(TF_Code.TF_OK, s);

              // NOTE(review): disposing the graph was disabled in the original
              // test; the reason is not visible here, so it stays disabled.
              // graph.Dispose();
          }
          finally
          {
              // Release the status object even when an assertion above throws.
              s.Dispose();
          }
      }

      /// <summary>
      /// Asserts that <paramref name="status"/> currently holds the
      /// <paramref name="expected"/> code, reporting both values on failure.
      /// </summary>
      /// <param name="expected">The status code the preceding call should have produced.</param>
      /// <param name="status">The status object that was passed to the preceding call.</param>
      private static void AssertStatusCode(TF_Code expected, Status status)
      {
          Assert.AreEqual(expected, status.Code);
      }
  }
  139. }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。