Browse Source

Implement SafeExecutorHandle as a wrapper for TFE_Executor

pull/578/head
Sam Harwell 5 years ago
parent
commit
3e917b60e8
5 changed files with 47 additions and 35 deletions
  1. +40
    -0
      src/TensorFlowNET.Core/Eager/SafeExecutorHandle.cs
  2. +0
    -23
      src/TensorFlowNET.Core/Eager/TFE_Executor.cs
  3. +4
    -5
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  4. +2
    -5
      test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs
  5. +1
    -2
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs

+ 40
- 0
src/TensorFlowNET.Core/Eager/SafeExecutorHandle.cs View File

@@ -0,0 +1,40 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using Tensorflow.Util;

namespace Tensorflow.Eager
{
public sealed class SafeExecutorHandle : SafeTensorflowHandle
{
private SafeExecutorHandle()
{
}

public SafeExecutorHandle(IntPtr handle)
: base(handle)
{
}

protected override bool ReleaseHandle()
{
c_api.TFE_DeleteExecutor(handle);
SetHandle(IntPtr.Zero);
return true;
}
}
}

+ 0
- 23
src/TensorFlowNET.Core/Eager/TFE_Executor.cs View File

@@ -1,23 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Eager
{
public struct TFE_Executor
{
IntPtr _handle;

public TFE_Executor(IntPtr handle)
=> _handle = handle;

public static implicit operator TFE_Executor(IntPtr handle)
=> new TFE_Executor(handle);

public static implicit operator IntPtr(TFE_Executor tensor)
=> tensor._handle;

public override string ToString()
=> $"TFE_Executor {_handle}";
}
}

+ 4
- 5
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -2,7 +2,6 @@
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using Tensorflow.Device; using Tensorflow.Device;
using Tensorflow.Eager; using Tensorflow.Eager;
using TFE_Executor = System.IntPtr;


namespace Tensorflow namespace Tensorflow
{ {
@@ -299,7 +298,7 @@ namespace Tensorflow
/// <param name="is_async"></param> /// <param name="is_async"></param>
/// <returns>TFE_Executor*</returns> /// <returns>TFE_Executor*</returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern TFE_Executor TFE_NewExecutor(bool is_async);
public static extern SafeExecutorHandle TFE_NewExecutor(bool is_async);


/// <summary> /// <summary>
/// Deletes the eager Executor without waiting for enqueued nodes. Please call /// Deletes the eager Executor without waiting for enqueued nodes. Please call
@@ -322,7 +321,7 @@ namespace Tensorflow
/// <param name="executor">TFE_Executor*</param> /// <param name="executor">TFE_Executor*</param>
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TFE_ExecutorWaitForAllPendingNodes(TFE_Executor executor, SafeStatusHandle status);
public static extern void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status);


/// <summary> /// <summary>
/// Sets a custom Executor for current thread. All nodes created by this thread /// Sets a custom Executor for current thread. All nodes created by this thread
@@ -331,7 +330,7 @@ namespace Tensorflow
/// <param name="ctx"></param> /// <param name="ctx"></param>
/// <param name="executor"></param> /// <param name="executor"></param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TFE_ContextSetExecutorForThread(SafeContextHandle ctx, TFE_Executor executor);
public static extern void TFE_ContextSetExecutorForThread(SafeContextHandle ctx, SafeExecutorHandle executor);


/// <summary> /// <summary>
/// Returns the Executor for current thread. /// Returns the Executor for current thread.
@@ -339,7 +338,7 @@ namespace Tensorflow
/// <param name="ctx"></param> /// <param name="ctx"></param>
/// <returns>TFE_Executor*</returns> /// <returns>TFE_Executor*</returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx);
public static extern SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables);


+ 2
- 5
test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs View File

@@ -124,13 +124,10 @@ namespace TensorFlowNET.UnitTest
protected void TFE_DeleteOp(IntPtr op) protected void TFE_DeleteOp(IntPtr op)
=> c_api.TFE_DeleteOp(op); => c_api.TFE_DeleteOp(op);


protected void TFE_DeleteExecutor(IntPtr executor)
=> c_api.TFE_DeleteExecutor(executor);

protected IntPtr TFE_ContextGetExecutorForThread(SafeContextHandle ctx)
protected SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx)
=> c_api.TFE_ContextGetExecutorForThread(ctx); => c_api.TFE_ContextGetExecutorForThread(ctx);


protected void TFE_ExecutorWaitForAllPendingNodes(IntPtr executor, SafeStatusHandle status)
protected void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status)
=> c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status); => c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status);


protected IntPtr TFE_TensorHandleResolve(IntPtr h, SafeStatusHandle status) protected IntPtr TFE_TensorHandleResolve(IntPtr h, SafeStatusHandle status)


+ 1
- 2
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs View File

@@ -65,10 +65,9 @@ namespace TensorFlowNET.UnitTest.NativeAPI


TFE_DeleteTensorHandle(hcpu); TFE_DeleteTensorHandle(hcpu);
// not export api // not export api
var executor = TFE_ContextGetExecutorForThread(ctx);
using var executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status); TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteExecutor(executor);
} }
} }
} }

Loading…
Cancel
Save