You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CApiGradientsTest.cs 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp.Core;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Runtime.InteropServices;
  6. using System.Text;
  7. using Tensorflow;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Port of tensorflow\c\c_api_test.cc
      /// `class CApiGradientsTest`.
      /// </summary>
      /// <remarks>
      /// NOTE(review): the port looks incomplete — <c>TestGradientsSuccess</c> only builds the
      /// forward graph; no gradients are computed or compared against <c>expected_graph_</c> yet.
      /// </remarks>
      [TestClass]
      public class CApiGradientsTest : CApiTest, IDisposable
      {
          // Graph that the test operations are added to.
          private readonly Graph graph_ = new Graph();
          // Intended to hold the expected gradient graph; nothing in this class populates it yet.
          private readonly Graph expected_graph_ = new Graph();
          // Status object passed to every C API call made by this fixture.
          private readonly Status s_ = new Status();

          /// <summary>
          /// Builds the MatMul test graph that the gradient check will run against.
          /// </summary>
          /// <param name="grad_inputs_provided">
          /// Presumably selects whether explicit upstream gradients are supplied.
          /// NOTE(review): currently unused because the gradient step is not ported yet.
          /// </param>
          private void TestGradientsSuccess(bool grad_inputs_provided)
          {
              var inputs = new TF_Output[2];
              var outputs = new TF_Output[1];
              // TODO: placeholders for the unported gradient computation and the
              // comparison against expected_graph_; not read anywhere yet.
              var grad_outputs = new TF_Output[2];
              var expected_grad_outputs = new TF_Output[2];

              BuildSuccessGraph(inputs, outputs);
          }

          /// <summary>
          /// Adds <c>z = MatMul(Const_0, Const_1)</c> to <c>graph_</c> and reports its endpoints.
          /// </summary>
          /// <param name="inputs">Receives the outputs of Const_0 and Const_1 (indices 0 and 1).</param>
          /// <param name="outputs">Receives the MatMul output at index 0.</param>
          private void BuildSuccessGraph(TF_Output[] inputs, TF_Output[] outputs)
          {
              // Construct the following graph:
              //            |
              //           z|
              //            |
              //          MatMul
              //         /      \
              //        ^        ^
              //        |        |
              //       x|       y|
              //        |        |
              //        |        |
              //      Const_0  Const_1
              //
              var const0_val = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
              // Const_1 holds the 2x2 identity matrix, so z should equal Const_0.
              var const1_val = new float[] { 1.0f, 0.0f, 0.0f, 1.0f };

              var const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0");
              var const1 = FloatConst2x2(graph_, s_, const1_val, "Const_1");
              var matmul = MatMul(graph_, s_, const0, const1, "MatMul");

              inputs[0] = new TF_Output(const0, 0);
              inputs[1] = new TF_Output(const1, 0);
              outputs[0] = new TF_Output(matmul, 0);
              EXPECT_EQ(TF_OK, TF_GetCode(s_));
          }

          /// <summary>
          /// Adds a float <c>Const</c> node with a 2x2 value named <paramref name="name"/>
          /// to <paramref name="graph"/>.
          /// </summary>
          /// <param name="graph">Graph to add the node to.</param>
          /// <param name="s">Status that receives the result of each C API call.</param>
          /// <param name="values">The four element values of the constant.</param>
          /// <param name="name">Name of the new operation.</param>
          /// <returns>
          /// The finished operation, or <see cref="IntPtr.Zero"/> (converted to
          /// <c>Operation</c>) when setting the "value" attribute fails; the failure
          /// is then left in <paramref name="s"/>.
          /// </returns>
          private Operation FloatConst2x2(Graph graph, Status s, float[] values, string name)
          {
              // NOTE(review): this tensor is never explicitly released here; confirm whether
              // the Tensor wrapper frees it or whether TF_DeleteTensor should be called.
              var tensor = FloatTensor2x2(values);
              var desc = TF_NewOperation(graph, "Const", name);
              TF_SetAttrTensor(desc, "value", tensor, s);
              if (TF_GetCode(s) != TF_OK)
              {
                  return IntPtr.Zero;
              }

              TF_SetAttrType(desc, "dtype", TF_FLOAT);
              var op = TF_FinishOperation(desc, s);
              EXPECT_EQ(TF_OK, TF_GetCode(s));
              return op;
          }

          /// <summary>
          /// Allocates a float tensor with dims {2, 2} and copies the first four
          /// elements of <paramref name="values"/> into its data buffer.
          /// </summary>
          /// <param name="values">Source elements; at least four are required.</param>
          /// <returns>The newly allocated tensor.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="values"/> is null.</exception>
          /// <exception cref="ArgumentException"><paramref name="values"/> has fewer than four elements.</exception>
          private Tensor FloatTensor2x2(float[] values)
          {
              const int elementCount = 4;
              if (values == null)
              {
                  throw new ArgumentNullException(nameof(values));
              }
              if (values.Length < elementCount)
              {
                  throw new ArgumentException($"Expected at least {elementCount} values.", nameof(values));
              }

              long[] dims = { 2, 2 };
              Tensor t = c_api.TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(float) * elementCount);
              // Copy into the tensor's data buffer, not over the tensor handle itself.
              Marshal.Copy(values, 0, c_api.TF_TensorData(t), elementCount);
              return t;
          }

          /// <summary>
          /// Adds a <c>MatMul</c> node computing <paramref name="l"/> x <paramref name="r"/>
          /// (using output 0 of each) to <paramref name="graph"/>.
          /// </summary>
          /// <param name="graph">Graph to add the node to.</param>
          /// <param name="s">Status that receives the result of the C API calls.</param>
          /// <param name="l">Left operand operation.</param>
          /// <param name="r">Right operand operation.</param>
          /// <param name="name">Name of the new operation.</param>
          /// <param name="transpose_a">When true, sets the "transpose_a" attribute.</param>
          /// <param name="transpose_b">When true, sets the "transpose_b" attribute.</param>
          /// <returns>The finished MatMul operation.</returns>
          private Operation MatMul(Graph graph, Status s, Operation l, Operation r, string name,
              bool transpose_a = false, bool transpose_b = false)
          {
              var desc = TF_NewOperation(graph, "MatMul", name);
              // Attributes are only set when requested.
              if (transpose_a)
              {
                  TF_SetAttrBool(desc, "transpose_a", true);
              }
              if (transpose_b)
              {
                  TF_SetAttrBool(desc, "transpose_b", true);
              }

              TF_AddInput(desc, new TF_Output(l, 0));
              TF_AddInput(desc, new TF_Output(r, 0));
              var op = TF_FinishOperation(desc, s);
              EXPECT_EQ(TF_OK, TF_GetCode(s));
              return op;
          }

          [TestMethod]
          public void Gradients_GradInputs()
          {
              TestGradientsSuccess(true);
          }

          [TestMethod]
          public void Gradients_NoGradInputs()
          {
              TestGradientsSuccess(false);
          }

          /// <summary>
          /// Disposes the two graphs and the status object owned by this fixture.
          /// </summary>
          public void Dispose()
          {
              graph_.Dispose();
              expected_graph_.Dispose();
              s_.Dispose();
          }
      }
  }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。