diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs index 230dcf97..8c99ac10 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs @@ -145,7 +145,7 @@ namespace Tensorflow.Eager return flat_result; } - TFE_Op GetOp(Context ctx, string op_or_function_name, Status status) + SafeOpHandle GetOp(Context ctx, string op_or_function_name, Status status) { if (thread_local_eager_operation_map.find(ctx, out var op)) c_api.TFE_OpReset(op, op_or_function_name, ctx.device_name, status.Handle); @@ -159,7 +159,7 @@ namespace Tensorflow.Eager return op; } - static UnorderedMap thread_local_eager_operation_map = new UnorderedMap(); + static UnorderedMap thread_local_eager_operation_map = new UnorderedMap(); bool HasAccumulator() { @@ -192,7 +192,7 @@ namespace Tensorflow.Eager ArgDef input_arg, List flattened_attrs, List flattened_inputs, - IntPtr op, + SafeOpHandle op, Status status) { IntPtr input_handle; @@ -224,7 +224,7 @@ namespace Tensorflow.Eager return true; } - public void SetOpAttrs(TFE_Op op, params object[] attrs) + public void SetOpAttrs(SafeOpHandle op, params object[] attrs) { var status = tf.status; var len = attrs.Length; @@ -257,7 +257,7 @@ namespace Tensorflow.Eager /// /// /// - void SetOpAttrWithDefaults(Context ctx, IntPtr op, AttrDef attr, + void SetOpAttrWithDefaults(Context ctx, SafeOpHandle op, AttrDef attr, string attr_name, object attr_value, Dictionary attr_list_sizes, Status status) @@ -290,7 +290,7 @@ namespace Tensorflow.Eager } } - bool SetOpAttrList(Context ctx, IntPtr op, + bool SetOpAttrList(Context ctx, SafeOpHandle op, string key, object value, TF_AttrType type, Dictionary attr_list_sizes, Status status) @@ -298,7 +298,7 @@ namespace Tensorflow.Eager return false; } - bool SetOpAttrScalar(Context ctx, IntPtr op, + bool SetOpAttrScalar(Context ctx, SafeOpHandle op, string key, object value, TF_AttrType type, Dictionary attr_list_sizes, Status status) diff --git a/src/TensorFlowNET.Core/Eager/SafeOpHandle.cs b/src/TensorFlowNET.Core/Eager/SafeOpHandle.cs new file mode 100644 index 00000000..2a50f412 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/SafeOpHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow.Eager +{ + public sealed class SafeOpHandle : SafeTensorflowHandle + { + private SafeOpHandle() + { + } + + public SafeOpHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TFE_DeleteOp(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/TFE_Op.cs b/src/TensorFlowNET.Core/Eager/TFE_Op.cs deleted file mode 100644 index e364f853..00000000 --- a/src/TensorFlowNET.Core/Eager/TFE_Op.cs +++ /dev/null @@ -1,23 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace Tensorflow.Eager -{ - public struct TFE_Op - { - IntPtr _handle; - - public TFE_Op(IntPtr handle) - => _handle = handle; - - public static implicit operator TFE_Op(IntPtr handle) - => new TFE_Op(handle); - - public static implicit operator IntPtr(TFE_Op tensor) - => tensor._handle; - - public override string ToString() - => $"TFE_Op 0x{_handle.ToString("x16")}"; - } -} diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index f78ae02a..d115734b 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -30,7 +30,7 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - public static extern TF_AttrType TFE_OpGetAttrType(IntPtr op, string attr_name, ref byte is_list, SafeStatusHandle status); + public static extern TF_AttrType TFE_OpGetAttrType(SafeOpHandle op, string attr_name, ref byte is_list, SafeStatusHandle status); [DllImport(TensorFlowLibName)] public static extern TF_AttrType TFE_OpNameGetAttrType(SafeContextHandle ctx, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); @@ -43,7 +43,7 @@ namespace Tensorflow /// const char* /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern int TFE_OpGetInputLength(IntPtr op, string input_name, SafeStatusHandle status); + public static extern int TFE_OpGetInputLength(SafeOpHandle op, string input_name, SafeStatusHandle status); /// /// Returns the length (number of tensors) of the output argument `output_name` @@ -54,7 +54,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern int TFE_OpGetOutputLength(IntPtr op, string input_name, SafeStatusHandle status); + public static extern int TFE_OpGetOutputLength(SafeOpHandle op, string input_name, SafeStatusHandle status); /// /// @@ -65,7 +65,7 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - public static extern int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status); + public static extern int TFE_OpAddInputList(SafeOpHandle op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status); /// /// @@ -98,7 +98,7 @@ namespace Tensorflow /// int* /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status); + public static extern void TFE_Execute(SafeOpHandle op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status); /// /// @@ -108,7 +108,7 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - public static extern TFE_Op TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status); + public static extern SafeOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status); /// /// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This @@ -124,7 +124,7 @@ namespace Tensorflow /// const char* /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern void TFE_OpReset(IntPtr op_to_reset, string op_or_function_name, string raw_device_name, SafeStatusHandle status); + public static extern void TFE_OpReset(SafeOpHandle op_to_reset, string op_or_function_name, string raw_device_name, SafeStatusHandle status); /// /// @@ -140,10 +140,10 @@ namespace Tensorflow /// const char* /// TF_DataType [DllImport(TensorFlowLibName)] - public static extern void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value); + public static extern void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value); [DllImport(TensorFlowLibName)] - public static extern void TFE_OpSetAttrInt(IntPtr op, string attr_name, long value); + public static extern void TFE_OpSetAttrInt(SafeOpHandle op, string attr_name, long value); /// /// @@ -154,10 +154,10 @@ namespace Tensorflow /// const int /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status); + public static extern void TFE_OpSetAttrShape(SafeOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status); [DllImport(TensorFlowLibName)] - public static extern void TFE_OpSetAttrBool(IntPtr op, string attr_name, bool value); + public static extern void TFE_OpSetAttrBool(SafeOpHandle op, string attr_name, bool value); /// /// @@ -167,7 +167,7 @@ namespace Tensorflow /// const void* /// size_t [DllImport(TensorFlowLibName)] - public static extern void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length); + public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length); /// /// @@ -176,7 +176,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern void TFE_OpSetDevice(TFE_Op op, string device_name, SafeStatusHandle status); + public static extern void TFE_OpSetDevice(SafeOpHandle op, string device_name, SafeStatusHandle status); /// /// @@ -185,7 +185,7 @@ namespace Tensorflow /// TFE_TensorHandle* /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern void TFE_OpAddInput(IntPtr op, IntPtr h, SafeStatusHandle status); + public static extern void TFE_OpAddInput(SafeOpHandle op, IntPtr h, SafeStatusHandle status); /// /// diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs b/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs index 5454d4ff..84a2f045 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs @@ -82,25 +82,25 @@ namespace TensorFlowNET.UnitTest protected ulong TF_TensorByteSize(IntPtr t) => c_api.TF_TensorByteSize(t); - protected void TFE_OpAddInput(IntPtr op, IntPtr h, SafeStatusHandle status) + protected void TFE_OpAddInput(SafeOpHandle op, IntPtr h, SafeStatusHandle status) => c_api.TFE_OpAddInput(op, h, status); - protected void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value) + protected void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value) => c_api.TFE_OpSetAttrType(op, attr_name, value); - protected void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status) + protected void TFE_OpSetAttrShape(SafeOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status) => c_api.TFE_OpSetAttrShape(op, attr_name, dims, num_dims, out_status); - protected void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length) + protected void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length) => c_api.TFE_OpSetAttrString(op, attr_name, value, length); - protected IntPtr TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) + protected SafeOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) => c_api.TFE_NewOp(ctx, op_or_function_name, status); protected IntPtr TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status) => c_api.TFE_NewTensorHandle(t, status); - protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status) + protected void TFE_Execute(SafeOpHandle op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status) => c_api.TFE_Execute(op, retvals, ref num_retvals, status); protected SafeContextOptionsHandle TFE_NewContextOptions() @@ -109,21 +109,18 @@ namespace TensorFlowNET.UnitTest protected SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status) => c_api.TFE_NewContext(opts, status); - protected int TFE_OpGetInputLength(IntPtr op, string input_name, SafeStatusHandle status) + protected int TFE_OpGetInputLength(SafeOpHandle op, string input_name, SafeStatusHandle status) => c_api.TFE_OpGetInputLength(op, input_name, status); - protected int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status) + protected int TFE_OpAddInputList(SafeOpHandle op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status) => c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); - protected int TFE_OpGetOutputLength(IntPtr op, string input_name, SafeStatusHandle status) + protected int TFE_OpGetOutputLength(SafeOpHandle op, string input_name, SafeStatusHandle status) => c_api.TFE_OpGetOutputLength(op, input_name, status); protected void TFE_DeleteTensorHandle(IntPtr h) => c_api.TFE_DeleteTensorHandle(h); - protected void TFE_DeleteOp(IntPtr op) - => c_api.TFE_DeleteOp(op); - protected SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx) => c_api.TFE_ContextGetExecutorForThread(ctx); @@ -154,7 +151,7 @@ namespace TensorFlowNET.UnitTest protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); - protected void TFE_OpSetDevice(IntPtr op, string device_name, SafeStatusHandle status) + protected void TFE_OpSetDevice(SafeOpHandle op, string device_name, SafeStatusHandle status) => c_api.TFE_OpSetDevice(op, device_name, status); } } diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs index 3e6c930b..3a90c2f2 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs @@ -34,13 +34,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); var m = TestMatrixTensorHandle(); - var matmul = MatMulOp(ctx, m, m); var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero }; - int num_retvals = 2; - c_api.TFE_Execute(matmul, retvals, ref num_retvals, status); - EXPECT_EQ(1, num_retvals); - EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - TFE_DeleteOp(matmul); + using (var matmul = MatMulOp(ctx, m, m)) + { + int num_retvals = 2; + c_api.TFE_Execute(matmul, retvals, ref num_retvals, status); + EXPECT_EQ(1, num_retvals); + EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + } + TFE_DeleteTensorHandle(m); t = TFE_TensorHandleResolve(retvals[0], status); diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs index 7b6c8f38..bae67a34 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs @@ -26,37 +26,38 @@ namespace TensorFlowNET.UnitTest.NativeAPI var input1 = TestMatrixTensorHandle(); var input2 = TestMatrixTensorHandle(); - var identityOp = TFE_NewOp(ctx, "IdentityN", status); - CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + var retvals = new IntPtr[2]; + using (var identityOp = TFE_NewOp(ctx, "IdentityN", status)) + { + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - // Try to retrieve lengths before building the attributes (should fail) - EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status)); - CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); - EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); - CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); + // Try to retrieve lengths before building the attributes (should fail) + EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status)); + CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); + EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); + CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); - var inputs = new IntPtr[] { input1, input2 }; - TFE_OpAddInputList(identityOp, inputs, 2, status); - CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + var inputs = new IntPtr[] { input1, input2 }; + TFE_OpAddInputList(identityOp, inputs, 2, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - // Try to retrieve lengths before executing the op (should work) - EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); - CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); - CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + // Try to retrieve lengths before executing the op (should work) + EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - var retvals = new IntPtr[2]; - int num_retvals = 2; - TFE_Execute(identityOp, retvals, ref num_retvals, status); - EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + int num_retvals = 2; + TFE_Execute(identityOp, retvals, ref num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - // Try to retrieve lengths after executing the op (should work) - EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); - CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); - CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + // Try to retrieve lengths after executing the op (should work) + EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + } - TFE_DeleteOp(identityOp); TFE_DeleteTensorHandle(input1); TFE_DeleteTensorHandle(input2); TFE_DeleteTensorHandle(retvals[0]); diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs index 9c903e4b..83d5902a 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs @@ -27,27 +27,28 @@ namespace TensorFlowNET.UnitTest.NativeAPI var condition = TestScalarTensorHandle(true); var t1 = TestMatrixTensorHandle(); var t2 = TestAxisTensorHandle(); - var assertOp = TFE_NewOp(ctx, "Assert", status); - CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - TFE_OpAddInput(assertOp, condition, status); - CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - var data = new[] { condition, t1, t2 }; - TFE_OpAddInputList(assertOp, data, 3, status); - CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + var retvals = new IntPtr[1]; + using (var assertOp = TFE_NewOp(ctx, "Assert", status)) + { + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TFE_OpAddInput(assertOp, condition, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + var data = new[] { condition, t1, t2 }; + TFE_OpAddInputList(assertOp, data, 3, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - /*var attr_values = Graph.TFE_GetOpDef("Assert").Attr; - var attr_found = attr_values.First(x => x.Name == "T"); - EXPECT_NE(attr_found, attr_values.Last());*/ - // EXPECT_EQ(attr_found.Type[0], "DT_BOOL"); - //EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); - //EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); + /*var attr_values = Graph.TFE_GetOpDef("Assert").Attr; + var attr_found = attr_values.First(x => x.Name == "T"); + EXPECT_NE(attr_found, attr_values.Last());*/ + // EXPECT_EQ(attr_found.Type[0], "DT_BOOL"); + //EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); + //EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); - var retvals = new IntPtr[1]; - int num_retvals = 1; - TFE_Execute(assertOp, retvals, ref num_retvals, status); - EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + int num_retvals = 1; + TFE_Execute(assertOp, retvals, ref num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + } - TFE_DeleteOp(assertOp); TFE_DeleteTensorHandle(condition); TFE_DeleteTensorHandle(t1); TFE_DeleteTensorHandle(t2); diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs index 6cee132f..d4dda7cc 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs @@ -40,25 +40,26 @@ namespace TensorFlowNET.UnitTest.NativeAPI var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status); ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); - var shape_op = ShapeOp(ctx, hgpu); - TFE_OpSetDevice(shape_op, gpu_device_name, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); var retvals = new IntPtr[1]; - int num_retvals = 1; - c_api.TFE_Execute(shape_op, retvals, ref num_retvals, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); + using (var shape_op = ShapeOp(ctx, hgpu)) + { + TFE_OpSetDevice(shape_op, gpu_device_name, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); + int num_retvals = 1; + c_api.TFE_Execute(shape_op, retvals, ref num_retvals, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); - // .device of shape is GPU since the op is executed on GPU - device_name = TFE_TensorHandleDeviceName(retvals[0], status); - ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - ASSERT_TRUE(device_name.Contains("GPU:0")); + // .device of shape is GPU since the op is executed on GPU + device_name = TFE_TensorHandleDeviceName(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + ASSERT_TRUE(device_name.Contains("GPU:0")); - // .backing_device of shape is CPU since the tensor is backed by CPU - backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status); - ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - ASSERT_TRUE(backing_device_name.Contains("CPU:0")); + // .backing_device of shape is CPU since the tensor is backed by CPU + backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + ASSERT_TRUE(backing_device_name.Contains("CPU:0")); + } - TFE_DeleteOp(shape_op); TFE_DeleteTensorHandle(retvals[0]); TFE_DeleteTensorHandle(hgpu); } diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs index df555bf0..a4d418ad 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs @@ -28,15 +28,16 @@ namespace TensorFlowNET.UnitTest.NativeAPI var var_handle = CreateVariable(ctx, 12.0f, status); ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - var op = TFE_NewOp(ctx, "ReadVariableOp", status); - ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - TFE_OpSetAttrType(op, "dtype", TF_FLOAT); - TFE_OpAddInput(op, var_handle, status); - ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); int num_retvals = 1; var value_handle = new[] { IntPtr.Zero }; - TFE_Execute(op, value_handle, ref num_retvals, status); - TFE_DeleteOp(op); + using (var op = TFE_NewOp(ctx, "ReadVariableOp", status)) + { + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpAddInput(op, var_handle, status); + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TFE_Execute(op, value_handle, ref num_retvals, status); + } ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); ASSERT_EQ(1, num_retvals); diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs index 4f9f6571..f11bab79 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs @@ -26,7 +26,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI return th; } - IntPtr MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b) + SafeOpHandle MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b) { using var status = TF_NewStatus(); @@ -64,7 +64,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI return false; } - IntPtr ShapeOp(SafeContextHandle ctx, IntPtr a) + SafeOpHandle ShapeOp(SafeContextHandle ctx, IntPtr a) { using var status = TF_NewStatus(); @@ -79,39 +79,43 @@ namespace TensorFlowNET.UnitTest.NativeAPI unsafe IntPtr CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status) { - var op = TFE_NewOp(ctx, "VarHandleOp", status); - if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; - TFE_OpSetAttrType(op, "dtype", TF_FLOAT); - TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); - TFE_OpSetAttrString(op, "container", "", 0); - TFE_OpSetAttrString(op, "shared_name", "", 0); - if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; var var_handle = new IntPtr[1]; int num_retvals = 1; - TFE_Execute(op, var_handle, ref num_retvals, status); - TFE_DeleteOp(op); + using (var op = TFE_NewOp(ctx, "VarHandleOp", status)) + { + if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); + TFE_OpSetAttrString(op, "container", "", 0); + TFE_OpSetAttrString(op, "shared_name", "", 0); + if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; + TFE_Execute(op, var_handle, ref num_retvals, status); + } + if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; CHECK_EQ(1, num_retvals); // Assign 'value' to it. - op = TFE_NewOp(ctx, "AssignVariableOp", status); - if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; - TFE_OpSetAttrType(op, "dtype", TF_FLOAT); - TFE_OpAddInput(op, var_handle[0], status); + using (var op = TFE_NewOp(ctx, "AssignVariableOp", status)) + { + if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpAddInput(op, var_handle[0], status); - // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. - var t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, new long[0], 0, sizeof(float)); - tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t)); + // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. + var t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, new long[0], 0, sizeof(float)); + tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t)); - var value_handle = c_api.TFE_NewTensorHandle(t, status); - if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; + var value_handle = c_api.TFE_NewTensorHandle(t, status); + if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; - TFE_OpAddInput(op, value_handle, status); - if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; + TFE_OpAddInput(op, value_handle, status); + if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; + + num_retvals = 0; + c_api.TFE_Execute(op, null, ref num_retvals, status); + } - num_retvals = 0; - c_api.TFE_Execute(op, null, ref num_retvals, status); - TFE_DeleteOp(op); if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; CHECK_EQ(0, num_retvals);