You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TensorTest.cs 8.1 kB

6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System;
  4. using System.Linq;
  5. using System.Runtime.InteropServices;
  6. using System.Threading;
  7. using Tensorflow;
  8. using static Tensorflow.Python;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Tests for tensor construction/disposal and the tensor-shape parts of the TensorFlow
      /// C API. Several tests are ports of tests in tensorflow/c/c_api_test.cc.
      /// </summary>
      [TestClass]
      public class TensorTest : CApiTest
      {
          /// <summary>
          /// Disposes the same 1000 tensors concurrently from two threads and checks that every
          /// tensor ends up disposed and that neither thread threw.
          /// </summary>
          /// <remarks>
          /// NOTE(review): the [TestMethod] attribute is commented out, so this test does not run;
          /// the reason is not visible in this file.
          /// </remarks>
          //[TestMethod]
          public void TensorDeallocationThreadSafety()
          {
              var tensors = new Tensor[1000];
              foreach (var i in range(1000))
              {
                  tensors[i] = new Tensor(new int[1000]);
              }

              // `start` releases both workers at (roughly) the same moment;
              // `done` is released once by each worker when it has finished.
              SemaphoreSlim start = new SemaphoreSlim(0, 2);
              SemaphoreSlim done = new SemaphoreSlim(0, 2);
              Exception failure = null;

              void DisposeAll()
              {
                  start.Wait();
                  try
                  {
                      foreach (var tensor in tensors)
                          tensor.Dispose();
                  }
                  catch (Exception ex)
                  {
                      // An exception escaping a worker thread would tear down the test host;
                      // record the first one and report it from the test thread instead.
                      Interlocked.CompareExchange(ref failure, ex, null);
                  }
                  finally
                  {
                      // Always signal completion so the test thread can never wait forever.
                      done.Release();
                  }
              }

              var t1 = new Thread(DisposeAll);
              var t2 = new Thread(DisposeAll);
              t1.Start();
              t2.Start();
              start.Release(2);
              done.Wait();
              done.Wait();
              t1.Join();
              t2.Join();

              Assert.IsNull(failure, $"Concurrent Dispose threw: {failure}");
              foreach (var t in tensors)
                  Assert.IsTrue(t.IsDisposed);
          }

          /// <summary>
          /// Wraps pinned managed memory (a sub-span and a whole array) in a tensor and checks
          /// that the tensor reports the expected byte size and does not claim ownership of it.
          /// </summary>
          [TestMethod]
          public unsafe void TensorFromFixed()
          {
              var array = new float[1000];
              var span = new Span<float>(array, 100, 500);

              fixed (float* ptr = &MemoryMarshal.GetReference(span))
              {
                  using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32, sizeof(float) * span.Length))
                  {
                      Assert.IsFalse(t.IsDisposed);
                      Assert.IsFalse(t.IsMemoryOwner);
                      Assert.AreEqual(2000, (int)t.bytesize);
                  }
              }

              fixed (float* ptr = &array[0])
              {
                  using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, sizeof(float) * array.Length))
                  {
                      Assert.IsFalse(t.IsDisposed);
                      Assert.IsFalse(t.IsMemoryOwner);
                      Assert.AreEqual(4000, (int)t.bytesize);
                  }
              }
          }

          /// <summary>
          /// Allocates a 2x3 float tensor through <c>TF_AllocateTensor</c> and checks its dtype,
          /// rank, first dimension and byte size.
          /// </summary>
          [TestMethod]
          public void AllocateTensor()
          {
              ulong num_bytes = 6 * sizeof(float);
              long[] dims = { 2, 3 };

              // `using` guarantees the native tensor is released even if an expectation fails.
              using (Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, dims.Length, num_bytes))
              {
                  EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype);
                  EXPECT_EQ(2, t.NDims);
                  EXPECT_EQ((int)dims[0], t.shape[0]);
                  EXPECT_EQ(num_bytes, t.bytesize);
              }
          }

          /// <summary>
          /// Port from c_api_test.cc
          /// `TEST(CAPI, MaybeMove)`
          /// </summary>
          [TestMethod]
          public void MaybeMove()
          {
              NDArray nd = np.array(2, 3);
              using (Tensor t = new Tensor(nd))
              {
                  Tensor o = t.MaybeMove();
                  ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own.
              }
          }

          /// <summary>
          /// Port from c_api_test.cc
          /// `TEST(CAPI, Tensor)`
          /// Builds a 2x3 float tensor from an NDArray and checks dtype, rank, shape, byte size
          /// and the element values read back from the tensor.
          /// </summary>
          [TestMethod]
          public void Tensor()
          {
              var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);
              using (var tensor = new Tensor(nd))
              {
                  var array = tensor.Data<float>();

                  EXPECT_EQ(TF_DataType.TF_FLOAT, tensor.dtype);
                  EXPECT_EQ(nd.ndim, tensor.rank);
                  EXPECT_EQ(nd.shape[0], (int)tensor.shape[0]);
                  EXPECT_EQ(nd.shape[1], (int)tensor.shape[1]);
                  EXPECT_EQ((ulong)nd.size * sizeof(float), tensor.bytesize);

                  // Check the data read back from the tensor, not the source NDArray.
                  Assert.IsTrue(Enumerable.SequenceEqual(array, new float[] { 1, 2, 3, 4, 5, 6 }));
              }
          }

          /// <summary>
          /// Port from tensorflow\c\c_api_test.cc
          /// `TEST(CAPI, SetShape)`
          /// Exercises TF_GraphSetTensorShape / TF_GraphGetTensorNumDims / TF_GraphGetTensorShape
          /// on a placeholder (unknown, partially known, fully known and invalid shapes) and on a
          /// scalar constant.
          /// </summary>
          [TestMethod]
          public void SetShape()
          {
              var s = new Status();
              try
              {
                  var graph = new Graph();
                  var feed = c_test_util.Placeholder(graph, s);
                  var feed_out_0 = new TF_Output(feed, 0);

                  // Fetch the shape, it should be completely unknown.
                  int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  EXPECT_EQ(-1, num_dims);

                  // Set the shape to be unknown, expect no change.
                  c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
                  EXPECT_EQ(-1, num_dims);

                  // Set the shape to be 2 x Unknown
                  long[] dims = { 2, -1 };
                  c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
                  EXPECT_EQ(2, num_dims);

                  // Get the dimension vector appropriately.
                  var returned_dims = new long[dims.Length];
                  c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));

                  // Set to a new valid shape: [2, 3]
                  dims[1] = 3;
                  c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);

                  // Fetch and see that the new value is returned.
                  c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));

                  // Try to set 'unknown' with unknown rank on the shape and see that
                  // it doesn't change.
                  c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  EXPECT_EQ(2, num_dims);
                  EXPECT_EQ(2, (int)returned_dims[0]);
                  EXPECT_EQ(3, (int)returned_dims[1]);

                  // Try to set 'unknown' with same rank on the shape and see that
                  // it doesn't change.
                  dims[0] = -1;
                  dims[1] = -1;
                  c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  EXPECT_EQ(2, num_dims);
                  EXPECT_EQ(2, (int)returned_dims[0]);
                  EXPECT_EQ(3, (int)returned_dims[1]);

                  // Try to fetch a shape with the wrong num_dims
                  c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s);
                  Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code);

                  // Try to set an invalid shape (cannot change 2x3 to a 2x5).
                  dims[1] = 5;
                  c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
                  Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code);

                  // Test for a scalar.
                  var three = c_test_util.ScalarConst(3, graph, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  var three_out_0 = new TF_Output(three, 0);

                  num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  EXPECT_EQ(0, num_dims);

                  // Query the scalar's own output (rank 0). Querying feed_out_0 (rank 2) with
                  // num_dims == 0 is a rank mismatch and would report INVALID_ARGUMENT.
                  c_api.TF_GraphGetTensorShape(graph, three_out_0, null, num_dims, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);

                  // NOTE(review): graph disposal was commented out in the original port; the
                  // reason is not visible here, so the graph is still not disposed.
              }
              finally
              {
                  s.Dispose();
              }
          }
      }
  }