From 941a2410554a7eadf79c98fa989bc9784104ded7 Mon Sep 17 00:00:00 2001 From: Sam Harwell Date: Mon, 27 Jul 2020 07:59:52 -0700 Subject: [PATCH] Simplify code using SafeTensorHandleHandle --- src/TensorFlowNET.Core/Eager/c_api.eager.cs | 15 ++++++++++++++ .../Eager/CApi.Eager.Execute_MatMul_CPU.cs | 20 +++++++++---------- .../CApi.Eager.OpGetInputAndOutputLengths.cs | 5 +++-- ...pi.Eager.OpInferMixedTypeInputListAttrs.cs | 3 ++- .../Eager/CApi.Eager.TensorHandleDevices.cs | 3 ++- .../NativeAPI/Eager/CApi.Eager.Variables.cs | 3 ++- .../NativeAPI/Eager/CApi.Eager.cs | 11 ++++------ 7 files changed, 38 insertions(+), 22 deletions(-) diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index 4c44628e..df68060b 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -91,6 +91,18 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TFE_DeleteContext(IntPtr ctx); + /// + /// Execute the operation defined by and return handles to computed + /// tensors in . + /// + /// + /// Upon successful return, the first slots in will + /// contain handle instances which the caller is responsible for disposing once they are no longer in use. + /// + /// + /// + /// + /// public static void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) { unsafe @@ -100,6 +112,9 @@ namespace Tensorflow TFE_Execute(op, rawReturns, ref num_retvals, status); for (var i = 0; i < num_retvals; i++) { + // A handle is created for every return, even if rawReturns[i] is null. The resulting handle will be + // non-null but invalid, which is the same behavior P/Invoke gives for non-array SafeHandle return + // values. retvals[i] = new SafeTensorHandleHandle(rawReturns[i]); } } diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs index 2e77fe9d..441f9d27 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs @@ -34,23 +34,23 @@ namespace TensorFlowNET.UnitTest.NativeAPI CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); var retvals = new SafeTensorHandleHandle[2]; - try + using (var m = TestMatrixTensorHandle()) + using (var matmul = MatMulOp(ctx, m, m)) { - using (var m = TestMatrixTensorHandle()) - using (var matmul = MatMulOp(ctx, m, m)) - { - int num_retvals; - c_api.TFE_Execute(matmul, retvals, out num_retvals, status); - EXPECT_EQ(1, num_retvals); - EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - } + int num_retvals; + c_api.TFE_Execute(matmul, retvals, out num_retvals, status); + EXPECT_EQ(1, num_retvals); + EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + } + try + { t = TFE_TensorHandleResolve(retvals[0], status); ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); } finally { - retvals[0]?.Dispose(); + retvals[0].Dispose(); } } diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs index 187f71e5..7f7318c1 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs @@ -51,6 +51,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI int num_retvals; TFE_Execute(identityOp, retvals, out num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + EXPECT_EQ(2, num_retvals); try { @@ -62,8 +63,8 @@ namespace TensorFlowNET.UnitTest.NativeAPI } finally { - retvals[0]?.Dispose(); - retvals[1]?.Dispose(); + retvals[0].Dispose(); + retvals[1].Dispose(); } } } diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs index 6e79c3e5..576cbde4 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs @@ -46,8 +46,9 @@ namespace TensorFlowNET.UnitTest.NativeAPI int num_retvals; TFE_Execute(assertOp, retvals, out num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + EXPECT_EQ(1, num_retvals); - retvals[0]?.Dispose(); + retvals[0].Dispose(); } } } diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs index ae3e07f2..4153ac5c 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs @@ -49,6 +49,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI int num_retvals; c_api.TFE_Execute(shape_op, retvals, out num_retvals, status); ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); + ASSERT_EQ(1, num_retvals); try { @@ -64,7 +65,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI } finally { - retvals[0]?.Dispose(); + retvals[0].Dispose(); } } } diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs index 7d3c4195..1d59a3d0 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs @@ -38,6 +38,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI TFE_OpAddInput(op, var_handle, status); ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); TFE_Execute(op, value_handle, out num_retvals, status); + ASSERT_EQ(1, num_retvals); } try @@ -57,7 +58,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI } finally { - value_handle[0]?.Dispose(); + value_handle[0].Dispose(); } } diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs index 82a1912a..b44d9cc0 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs @@ -90,11 +90,10 @@ namespace TensorFlowNET.UnitTest.NativeAPI TFE_OpSetAttrString(op, "shared_name", "", 0); if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); TFE_Execute(op, var_handle, out num_retvals, status); + if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); + CHECK_EQ(1, num_retvals); } - if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); - CHECK_EQ(1, num_retvals); - // Assign 'value' to it. using (var op = TFE_NewOp(ctx, "AssignVariableOp", status)) { @@ -112,13 +111,11 @@ namespace TensorFlowNET.UnitTest.NativeAPI TFE_OpAddInput(op, value_handle, status); if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); - num_retvals = 0; c_api.TFE_Execute(op, null, out num_retvals, status); + if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); + CHECK_EQ(0, num_retvals); } - if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); - CHECK_EQ(0, num_retvals); - return var_handle[0]; }