You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TensorTest.cs 3.7 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp.Core;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Runtime.InteropServices;
  7. using System.Text;
  8. using Tensorflow;
  9. namespace TensorFlowNET.UnitTest
  10. {
  /// <summary>
  /// Unit tests for <see cref="Tensor"/> construction and for the graph tensor-shape
  /// functions exposed through <c>c_api</c>.
  /// </summary>
  [TestClass]
  public class TensorTest
  {
      /// <summary>
      /// Wraps a 2x3 float NDArray in a <see cref="Tensor"/> and checks that dtype,
      /// rank, shape, byte size and element data match the source array.
      /// </summary>
      [TestMethod]
      public void NewTensor()
      {
          var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);

          var tensor = new Tensor(nd);
          var array = tensor.Data<float>();

          // MSTest's Assert.AreEqual signature is (expected, actual); keeping that order
          // makes failure messages label the values correctly.
          Assert.AreEqual(TF_DataType.TF_FLOAT, tensor.dtype);
          Assert.AreEqual(nd.ndim, tensor.rank);
          Assert.AreEqual(nd.shape[0], tensor.shape[0]);
          Assert.AreEqual(nd.shape[1], tensor.shape[1]);
          Assert.AreEqual((uint)nd.size * sizeof(float), tensor.bytesize);
          Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array),
              "Tensor data differs from the source NDArray data.");
      }

      /// <summary>
      /// Port from tensorflow\c\c_api_test.cc.
      /// Builds a float Placeholder without a "shape" attribute and exercises
      /// TF_GraphGetTensorNumDims / TF_GraphSetTensorShape / TF_GraphGetTensorShape on
      /// its first output, then checks that a scalar constant reports rank 0.
      /// </summary>
      [TestMethod]
      public void SetShape()
      {
          var s = new Status();
          // NOTE(review): this uses tf.get_default_graph() rather than a freshly created
          // graph, so the ops added here presumably stay visible to other tests sharing
          // that graph — confirm this is intended.
          var graph = tf.get_default_graph();

          // Build a float Placeholder. No "shape" attr is set (the ported C++ only sets
          // it when dims are supplied), so the output shape starts out unknown.
          var desc = c_api.TF_NewOperation(graph, "Placeholder", "");
          c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_FLOAT);
          var op = c_api.TF_FinishOperation(desc, s);
          Assert.AreEqual(TF_Code.TF_OK, s.Code);
          Assert.IsNotNull(op);

          // Fetch the shape, it should be completely unknown (rank -1).
          var feed_out_0 = new TF_Output { oper = op, index = 0 };
          int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
          Assert.AreEqual(TF_Code.TF_OK, s.Code);
          Assert.AreEqual(-1, num_dims);

          // Set the shape to be unknown (num_dims == -1), expect no change.
          c_api.TF_GraphSetTensorShape(graph, feed_out_0, new int[0], -1, s);
          Assert.AreEqual(TF_Code.TF_OK, s.Code);
          num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
          // The status must be checked too: a -1 result alone does not prove success.
          Assert.AreEqual(TF_Code.TF_OK, s.Code);
          Assert.AreEqual(-1, num_dims);

          // Set the shape to be 2 x Unknown.
          var dims = new int[] { 2, -1 };
          c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
          Assert.AreEqual(TF_Code.TF_OK, s.Code);
          num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
          Assert.AreEqual(TF_Code.TF_OK, s.Code);
          Assert.AreEqual(2, num_dims);

          // Get the dimension vector appropriately.
          var returned_dims = new int[dims.Length];
          c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
          Assert.AreEqual(TF_Code.TF_OK, s.Code);
          Assert.IsTrue(dims.SequenceEqual(returned_dims),
              $"Expected shape [{string.Join(", ", dims)}] but got [{string.Join(", ", returned_dims)}].");

          // Set to a new valid shape: [2, 3].
          dims[1] = 3;
          c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
          // TODO(review): the checks for this step were disabled in the original port
          // (reason not recorded). They expect TF_OK and returned_dims == dims; re-enable
          // once they are confirmed to pass against the native library.
          //Assert.AreEqual(TF_Code.TF_OK, s.Code);

          // Fetch and see that the new value is returned.
          c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
          //Assert.AreEqual(TF_Code.TF_OK, s.Code);
          //Assert.IsTrue(dims.SequenceEqual(returned_dims));

          // Test for a scalar: the constant's first output must have rank 0.
          var three = c_test_util.ScalarConst(3, graph, s);
          Assert.AreEqual(TF_Code.TF_OK, s.Code);
          var three_out_0 = new TF_Output { oper = three.Handle, index = 0 };
          num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s);
          Assert.AreEqual(TF_Code.TF_OK, s.Code);
          Assert.AreEqual(0, num_dims);
      }
  }
  81. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。

Contributors (1)