You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CApi.Eager.Execute_MatMul_CPU.cs 2.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using Tensorflow;
  4. using Tensorflow.Eager;
  5. using static Tensorflow.Binding;
  6. namespace TensorFlowNET.UnitTest.NativeAPI
  7. {
  public partial class CApiEagerTest
  {
    /// <summary>
    /// TEST(CAPI, Execute_MatMul_CPU)
    /// </summary>
    [TestMethod]
    public unsafe void Execute_MatMul_CPU()
    {
      Execute_MatMul_CPU(false);
    }

    /// <summary>
    /// Builds a MatMul op that multiplies the test matrix handle by itself, executes it on a
    /// freshly created eager context, resolves the result to a TF_Tensor and checks that its
    /// four float elements are { 7, 10, 15, 22 }.
    /// </summary>
    /// <param name="async">Forwarded to <c>TFE_ContextOptionsSetAsync</c> when creating the context.</param>
    unsafe void Execute_MatMul_CPU(bool async)
    {
      using var status = TF_NewStatus();

      // Creates an eager context; the options object is released as soon as the context exists.
      static SafeContextHandle NewContext(bool async, SafeStatusHandle status)
      {
        using var opts = c_api.TFE_NewContextOptions();
        c_api.TFE_ContextOptionsSetAsync(opts, Convert.ToByte(async));
        return c_api.TFE_NewContext(opts, status);
      }

      // Raw native tensor produced by TFE_TensorHandleResolve. It is not wrapped in a safe
      // handle, so it is released explicitly in the outer finally, even when a check throws.
      var t = IntPtr.Zero;
      try
      {
        using (var ctx = NewContext(async, status))
        {
          CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
          var retvals = new SafeTensorHandleHandle[2];
          try
          {
            using (var m = TestMatrixTensorHandle())
            using (var matmul = MatMulOp(ctx, m, m))
            {
              c_api.TFE_Execute(matmul, retvals, out int num_retvals, status);
              EXPECT_EQ(1, num_retvals);
              EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
            }

            t = TFE_TensorHandleResolve(retvals[0], status);
            ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
          }
          finally
          {
            // Release every handle TFE_Execute may have filled in, not only the first one.
            foreach (var retval in retvals)
            {
              retval?.Dispose();
            }
          }
        }

        // The resolved tensor is inspected only after the context has been disposed.
        ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
        var product = new float[4];
        EXPECT_EQ(product.Length * sizeof(float), (int)TF_TensorByteSize(t));
        tf.memcpy(product, TF_TensorData(t), TF_TensorByteSize(t));
        EXPECT_EQ(7f, product[0]);
        EXPECT_EQ(10f, product[1]);
        EXPECT_EQ(15f, product[2]);
        EXPECT_EQ(22f, product[3]);
      }
      finally
      {
        if (t != IntPtr.Zero)
        {
          c_api.TF_DeleteTensor(t);
        }
      }
    }
  }
  61. }