You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TensorTest.cs 5.8 kB

6 years ago
Lines 1–146
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp.Core;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Runtime.InteropServices;
  7. using System.Text;
  8. using Tensorflow;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Ports of the tensor- and shape-related cases from TensorFlow's c_api_test.cc.
      /// </summary>
      [TestClass]
      public class TensorTest : CApiTest
      {
          /// <summary>
          /// Port from c_api_test.cc
          /// `TEST(CAPI, AllocateTensor)`
          /// Allocates a 2x3 float tensor through the C API and checks its dtype,
          /// rank, shape and byte size.
          /// </summary>
          [TestMethod]
          public void c_api_AllocateTensor()
          {
              ulong num_bytes = 6 * sizeof(float);
              long[] dims = { 2, 3 };
              Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, dims.Length, num_bytes);
              try
              {
                  EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype);
                  EXPECT_EQ(2, t.NDims);
                  Assert.IsTrue(Enumerable.SequenceEqual(dims, t.shape));
                  EXPECT_EQ(num_bytes, t.bytesize);
              }
              finally
              {
                  // Release the tensor even when one of the assertions above throws.
                  t.Dispose();
              }
          }

          /// <summary>
          /// Port from c_api_test.cc
          /// `TEST(CAPI, Tensor)`
          /// Wraps a 2x3 NumSharp array in a Tensor and checks that dtype, rank,
          /// shape, byte size and element data round-trip unchanged.
          /// </summary>
          [TestMethod]
          public void c_api_Tensor()
          {
              var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);
              var tensor = new Tensor(nd);
              try
              {
                  var array = tensor.Data<float>();
                  // Arguments are (expected, actual) so failure messages read correctly.
                  EXPECT_EQ(TF_DataType.TF_FLOAT, tensor.dtype);
                  EXPECT_EQ(nd.ndim, tensor.rank);
                  EXPECT_EQ(nd.shape[0], tensor.shape[0]);
                  EXPECT_EQ(nd.shape[1], tensor.shape[1]);
                  EXPECT_EQ((uint)nd.size * sizeof(float), tensor.bytesize);
                  Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array));
              }
              finally
              {
                  // The original never released this tensor.
                  tensor.Dispose();
              }
          }

          /// <summary>
          /// Port from tensorflow\c\c_api_test.cc
          /// `TEST(CAPI, SetShape)`
          /// Exercises TF_GraphSetTensorShape / TF_GraphGetTensorShape /
          /// TF_GraphGetTensorNumDims on a placeholder, including the
          /// INVALID_ARGUMENT paths, and finally on a scalar constant.
          /// </summary>
          [TestMethod]
          public void c_api_SetShape()
          {
              var s = new Status();
              try
              {
                  // NOTE(review): the original port left `graph.Dispose()` commented out,
                  // so the graph is still not disposed here — confirm whether it should be.
                  var graph = new Graph();
                  var feed = c_test_util.Placeholder(graph, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  var feed_out_0 = new TF_Output(feed, 0);

                  // Fetch the shape, it should be completely unknown.
                  int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  EXPECT_EQ(-1, num_dims);

                  // Set the shape to be unknown, expect no change.
                  c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  EXPECT_EQ(-1, num_dims);

                  // Set the shape to be 2 x Unknown
                  long[] dims = { 2, -1 };
                  c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  EXPECT_EQ(2, num_dims);

                  // Get the dimension vector appropriately.
                  var returned_dims = new long[dims.Length];
                  c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  CollectionAssert.AreEqual(dims, returned_dims);

                  // Set to a new valid shape: [2, 3]
                  dims[1] = 3;
                  c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);

                  // Fetch and see that the new value is returned.
                  c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  CollectionAssert.AreEqual(dims, returned_dims);

                  // Try to set 'unknown' with unknown rank on the shape and see that
                  // it doesn't change.
                  c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  EXPECT_EQ(2, num_dims);
                  EXPECT_EQ(2, returned_dims[0]);
                  EXPECT_EQ(3, returned_dims[1]);

                  // Try to set 'unknown' with same rank on the shape and see that
                  // it doesn't change.
                  dims[0] = -1;
                  dims[1] = -1;
                  c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  EXPECT_EQ(2, num_dims);
                  EXPECT_EQ(2, returned_dims[0]);
                  EXPECT_EQ(3, returned_dims[1]);

                  // Try to fetch a shape with the wrong num_dims
                  c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s);
                  Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code);

                  // Try to set an invalid shape (cannot change 2x3 to a 2x5).
                  dims[1] = 5;
                  c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
                  Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code);

                  // Test for a scalar.
                  var three = c_test_util.ScalarConst(3, graph, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  var three_out_0 = new TF_Output(three, 0);
                  num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  EXPECT_EQ(0, num_dims);
                  // Query the scalar's own output. The original queried feed_out_0
                  // (rank 2) with num_dims == 0, which yields INVALID_ARGUMENT and is
                  // why its status assertion had been commented out.
                  c_api.TF_GraphGetTensorShape(graph, three_out_0, returned_dims, num_dims, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
              }
              finally
              {
                  s.Dispose();
              }
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。

Contributors (1)