diff --git a/src/TensorFlowNET.Core/Eager/SafeExecutorHandle.cs b/src/TensorFlowNET.Core/Eager/SafeExecutorHandle.cs new file mode 100644 index 00000000..cf6601e7 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/SafeExecutorHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow.Eager +{ + public sealed class SafeExecutorHandle : SafeTensorflowHandle + { + private SafeExecutorHandle() + { + } + + public SafeExecutorHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TFE_DeleteExecutor(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/TFE_Executor.cs b/src/TensorFlowNET.Core/Eager/TFE_Executor.cs deleted file mode 100644 index ed88dd30..00000000 --- a/src/TensorFlowNET.Core/Eager/TFE_Executor.cs +++ /dev/null @@ -1,23 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace Tensorflow.Eager -{ - public struct TFE_Executor - { - IntPtr _handle; - - public TFE_Executor(IntPtr handle) - => _handle = handle; - - public static implicit operator TFE_Executor(IntPtr handle) - => new TFE_Executor(handle); - - public static implicit operator IntPtr(TFE_Executor tensor) - => tensor._handle; - - public override string ToString() - => $"TFE_Executor {_handle}"; - } -} diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index f49b1f05..f78ae02a 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -2,7 +2,6 @@ using System.Runtime.InteropServices; using Tensorflow.Device; using Tensorflow.Eager; -using TFE_Executor = System.IntPtr; namespace Tensorflow { @@ -299,7 +298,7 @@ namespace Tensorflow /// /// TFE_Executor* [DllImport(TensorFlowLibName)] - public static extern TFE_Executor TFE_NewExecutor(bool is_async); + public static extern SafeExecutorHandle TFE_NewExecutor(bool is_async); /// /// Deletes the eager Executor without waiting for enqueued nodes. Please call @@ -322,7 +321,7 @@ namespace Tensorflow /// TFE_Executor* /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern void TFE_ExecutorWaitForAllPendingNodes(TFE_Executor executor, SafeStatusHandle status); + public static extern void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status); /// /// Sets a custom Executor for current thread. All nodes created by this thread @@ -331,7 +330,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern void TFE_ContextSetExecutorForThread(SafeContextHandle ctx, TFE_Executor executor); + public static extern void TFE_ContextSetExecutorForThread(SafeContextHandle ctx, SafeExecutorHandle executor); /// /// Returns the Executor for current thread. @@ -339,7 +338,7 @@ namespace Tensorflow /// /// TFE_Executor* [DllImport(TensorFlowLibName)] - public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx); + public static extern SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx); [DllImport(TensorFlowLibName)] public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs b/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs index 3353f531..5454d4ff 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs @@ -124,13 +124,10 @@ namespace TensorFlowNET.UnitTest protected void TFE_DeleteOp(IntPtr op) => c_api.TFE_DeleteOp(op); - protected void TFE_DeleteExecutor(IntPtr executor) - => c_api.TFE_DeleteExecutor(executor); - - protected IntPtr TFE_ContextGetExecutorForThread(SafeContextHandle ctx) + protected SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx) => c_api.TFE_ContextGetExecutorForThread(ctx); - protected void TFE_ExecutorWaitForAllPendingNodes(IntPtr executor, SafeStatusHandle status) + protected void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status) => c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status); protected IntPtr TFE_TensorHandleResolve(IntPtr h, SafeStatusHandle status) diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs index 812883c3..6cee132f 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs @@ -65,10 +65,9 @@ namespace TensorFlowNET.UnitTest.NativeAPI TFE_DeleteTensorHandle(hcpu); // not export api - var executor = TFE_ContextGetExecutorForThread(ctx); + using var executor = TFE_ContextGetExecutorForThread(ctx); TFE_ExecutorWaitForAllPendingNodes(executor, status); ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); - TFE_DeleteExecutor(executor); } } }