@@ -0,0 +1,40 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using Tensorflow.Util; | |||||
namespace Tensorflow.Eager | |||||
{ | |||||
public sealed class SafeExecutorHandle : SafeTensorflowHandle | |||||
{ | |||||
private SafeExecutorHandle() | |||||
{ | |||||
} | |||||
public SafeExecutorHandle(IntPtr handle) | |||||
: base(handle) | |||||
{ | |||||
} | |||||
protected override bool ReleaseHandle() | |||||
{ | |||||
c_api.TFE_DeleteExecutor(handle); | |||||
SetHandle(IntPtr.Zero); | |||||
return true; | |||||
} | |||||
} | |||||
} |
@@ -1,23 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Eager | |||||
{ | |||||
public struct TFE_Executor | |||||
{ | |||||
IntPtr _handle; | |||||
public TFE_Executor(IntPtr handle) | |||||
=> _handle = handle; | |||||
public static implicit operator TFE_Executor(IntPtr handle) | |||||
=> new TFE_Executor(handle); | |||||
public static implicit operator IntPtr(TFE_Executor tensor) | |||||
=> tensor._handle; | |||||
public override string ToString() | |||||
=> $"TFE_Executor {_handle}"; | |||||
} | |||||
} |
@@ -2,7 +2,6 @@ | |||||
using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
using Tensorflow.Device; | using Tensorflow.Device; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using TFE_Executor = System.IntPtr; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -299,7 +298,7 @@ namespace Tensorflow | |||||
/// <param name="is_async"></param> | /// <param name="is_async"></param> | ||||
/// <returns>TFE_Executor*</returns> | /// <returns>TFE_Executor*</returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TFE_Executor TFE_NewExecutor(bool is_async); | |||||
public static extern SafeExecutorHandle TFE_NewExecutor(bool is_async); | |||||
/// <summary> | /// <summary> | ||||
/// Deletes the eager Executor without waiting for enqueued nodes. Please call | /// Deletes the eager Executor without waiting for enqueued nodes. Please call | ||||
@@ -322,7 +321,7 @@ namespace Tensorflow | |||||
/// <param name="executor">TFE_Executor*</param> | /// <param name="executor">TFE_Executor*</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_ExecutorWaitForAllPendingNodes(TFE_Executor executor, SafeStatusHandle status); | |||||
public static extern void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Sets a custom Executor for current thread. All nodes created by this thread | /// Sets a custom Executor for current thread. All nodes created by this thread | ||||
@@ -331,7 +330,7 @@ namespace Tensorflow | |||||
/// <param name="ctx"></param> | /// <param name="ctx"></param> | ||||
/// <param name="executor"></param> | /// <param name="executor"></param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_ContextSetExecutorForThread(SafeContextHandle ctx, TFE_Executor executor); | |||||
public static extern void TFE_ContextSetExecutorForThread(SafeContextHandle ctx, SafeExecutorHandle executor); | |||||
/// <summary> | /// <summary> | ||||
/// Returns the Executor for current thread. | /// Returns the Executor for current thread. | ||||
@@ -339,7 +338,7 @@ namespace Tensorflow | |||||
/// <param name="ctx"></param> | /// <param name="ctx"></param> | ||||
/// <returns>TFE_Executor*</returns> | /// <returns>TFE_Executor*</returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx); | |||||
public static extern SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); | public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); | ||||
@@ -124,13 +124,10 @@ namespace TensorFlowNET.UnitTest | |||||
protected void TFE_DeleteOp(IntPtr op) | protected void TFE_DeleteOp(IntPtr op) | ||||
=> c_api.TFE_DeleteOp(op); | => c_api.TFE_DeleteOp(op); | ||||
protected void TFE_DeleteExecutor(IntPtr executor) | |||||
=> c_api.TFE_DeleteExecutor(executor); | |||||
protected IntPtr TFE_ContextGetExecutorForThread(SafeContextHandle ctx) | |||||
protected SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx) | |||||
=> c_api.TFE_ContextGetExecutorForThread(ctx); | => c_api.TFE_ContextGetExecutorForThread(ctx); | ||||
protected void TFE_ExecutorWaitForAllPendingNodes(IntPtr executor, SafeStatusHandle status) | |||||
protected void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status) | |||||
=> c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status); | => c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status); | ||||
protected IntPtr TFE_TensorHandleResolve(IntPtr h, SafeStatusHandle status) | protected IntPtr TFE_TensorHandleResolve(IntPtr h, SafeStatusHandle status) | ||||
@@ -65,10 +65,9 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
TFE_DeleteTensorHandle(hcpu); | TFE_DeleteTensorHandle(hcpu); | ||||
// not export api | // not export api | ||||
var executor = TFE_ContextGetExecutorForThread(ctx); | |||||
using var executor = TFE_ContextGetExecutorForThread(ctx); | |||||
TFE_ExecutorWaitForAllPendingNodes(executor, status); | TFE_ExecutorWaitForAllPendingNodes(executor, status); | ||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
TFE_DeleteExecutor(executor); | |||||
} | } | ||||
} | } | ||||
} | } |