Browse Source

Implement SafeContextHandle as a wrapper for TFE_Context

tags/v0.20
Sam Harwell 5 years ago
parent
commit
ab6d624ea8
14 changed files with 103 additions and 92 deletions
  1. +2
    -1
      src/TensorFlowNET.Core/Device/c_api.device.cs
  2. +8
    -15
      src/TensorFlowNET.Core/Eager/Context.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerOperation.cs
  4. +40
    -0
      src/TensorFlowNET.Core/Eager/SafeContextHandle.cs
  5. +0
    -23
      src/TensorFlowNET.Core/Eager/TFE_Context.cs
  6. +10
    -10
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  7. +6
    -8
      test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs
  8. +7
    -6
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs
  9. +20
    -16
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs
  10. +1
    -2
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs
  11. +1
    -2
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs
  12. +1
    -2
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs
  13. +1
    -2
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs
  14. +5
    -4
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs

+ 2
- 1
src/TensorFlowNET.Core/Device/c_api.device.cs View File

@@ -16,6 +16,7 @@

using System;
using System.Runtime.InteropServices;
using Tensorflow.Eager;

namespace Tensorflow
{
@@ -64,7 +65,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns>TFE_TensorHandle*</returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, IntPtr ctx, string device_name, SafeStatusHandle status);
public static extern IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status);

/// <summary>
/// Retrieves the full name of the device (e.g. /job:worker/replica:0/...)


+ 8
- 15
src/TensorFlowNET.Core/Eager/Context.cs View File

@@ -2,7 +2,7 @@

namespace Tensorflow.Eager
{
public class Context : DisposableObject
public sealed class Context : IDisposable
{
public const int GRAPH_MODE = 0;
public const int EAGER_MODE = 1;
@@ -12,9 +12,11 @@ namespace Tensorflow.Eager
public string scope_name = "";
bool _initialized = false;

public SafeContextHandle Handle { get; }

public Context(ContextOptions opts, Status status)
{
_handle = c_api.TFE_NewContext(opts, status.Handle);
Handle = c_api.TFE_NewContext(opts, status.Handle);
status.Check(true);
}

@@ -29,16 +31,10 @@ namespace Tensorflow.Eager
}

public void start_step()
=> c_api.TFE_ContextStartStep(_handle);
=> c_api.TFE_ContextStartStep(Handle);

public void end_step()
=> c_api.TFE_ContextEndStep(_handle);

/// <summary>
/// Dispose any unmanaged resources related to given <paramref name="handle"/>.
/// </summary>
protected sealed override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TFE_DeleteContext(_handle);
=> c_api.TFE_ContextEndStep(Handle);

public bool executing_eagerly()
=> default_execution_mode == EAGER_MODE;
@@ -48,10 +44,7 @@ namespace Tensorflow.Eager
name :
"cd2c89b7-88b7-44c8-ad83-06c2a9158347";

public static implicit operator IntPtr(Context ctx)
=> ctx._handle;

public static implicit operator TFE_Context(Context ctx)
=> new TFE_Context(ctx._handle);
public void Dispose()
=> Handle.Dispose();
}
}

+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerOperation.cs View File

@@ -53,7 +53,7 @@ namespace Tensorflow.Eager
{
object value = null;
byte isList = 0;
var attrType = c_api.TFE_OpNameGetAttrType(tf.context, Name, attr_name, ref isList, tf.status.Handle);
var attrType = c_api.TFE_OpNameGetAttrType(tf.context.Handle, Name, attr_name, ref isList, tf.status.Handle);
switch (attrType)
{
case TF_AttrType.TF_ATTR_BOOL:


+ 40
- 0
src/TensorFlowNET.Core/Eager/SafeContextHandle.cs View File

@@ -0,0 +1,40 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using Tensorflow.Util;

namespace Tensorflow.Eager
{
public sealed class SafeContextHandle : SafeTensorflowHandle
{
public SafeContextHandle()
{
}

public SafeContextHandle(IntPtr handle)
: base(handle)
{
}

protected override bool ReleaseHandle()
{
c_api.TFE_DeleteContext(handle);
SetHandle(IntPtr.Zero);
return true;
}
}
}

+ 0
- 23
src/TensorFlowNET.Core/Eager/TFE_Context.cs View File

@@ -1,23 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Eager
{
public struct TFE_Context
{
IntPtr _handle;

public TFE_Context(IntPtr handle)
=> _handle = handle;

public static implicit operator TFE_Context(IntPtr handle)
=> new TFE_Context(handle);

public static implicit operator IntPtr(TFE_Context tensor)
=> tensor._handle;

public override string ToString()
=> $"TFE_Context {_handle}";
}
}

+ 10
- 10
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -73,7 +73,7 @@ namespace Tensorflow
public static extern TF_AttrType TFE_OpGetAttrType(IntPtr op, string attr_name, ref byte is_list, SafeStatusHandle status);

[DllImport(TensorFlowLibName)]
public static extern TF_AttrType TFE_OpNameGetAttrType(IntPtr ct, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status);
public static extern TF_AttrType TFE_OpNameGetAttrType(SafeContextHandle ctx, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status);

/// <summary>
/// Returns the length (number of tensors) of the input argument `input_name`
@@ -114,13 +114,13 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns>TFE_Context*</returns>
[DllImport(TensorFlowLibName)]
public static extern TFE_Context TFE_NewContext(IntPtr opts, SafeStatusHandle status);
public static extern SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status);

[DllImport(TensorFlowLibName)]
public static extern TFE_Context TFE_ContextStartStep(IntPtr ctx);
public static extern void TFE_ContextStartStep(SafeContextHandle ctx);

[DllImport(TensorFlowLibName)]
public static extern TFE_Context TFE_ContextEndStep(IntPtr ctx);
public static extern void TFE_ContextEndStep(SafeContextHandle ctx);

/// <summary>
///
@@ -148,7 +148,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern TFE_Op TFE_NewOp(IntPtr ctx, string op_or_function_name, SafeStatusHandle status);
public static extern TFE_Op TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status);

/// <summary>
/// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This
@@ -317,7 +317,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_ContextListDevices(IntPtr ctx, SafeStatusHandle status);
public static extern IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status);

/// <summary>
///
@@ -379,7 +379,7 @@ namespace Tensorflow
/// <param name="ctx"></param>
/// <param name="executor"></param>
[DllImport(TensorFlowLibName)]
public static extern void TFE_ContextSetExecutorForThread(IntPtr ctx, TFE_Executor executor);
public static extern void TFE_ContextSetExecutorForThread(SafeContextHandle ctx, TFE_Executor executor);

/// <summary>
/// Returns the Executor for current thread.
@@ -387,7 +387,7 @@ namespace Tensorflow
/// <param name="ctx"></param>
/// <returns>TFE_Executor*</returns>
[DllImport(TensorFlowLibName)]
public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx);
public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx);

/// <summary>
///
@@ -402,7 +402,7 @@ namespace Tensorflow
/// <param name="status"></param>
/// <returns>EagerTensorHandle</returns>
[DllImport(TensorFlowLibName)]
public static extern SafeStatusHandle TFE_FastPathExecute(IntPtr ctx,
public static extern SafeStatusHandle TFE_FastPathExecute(SafeContextHandle ctx,
string device_name,
string op_name,
string name,
@@ -416,7 +416,7 @@ namespace Tensorflow
public delegate void TFE_FastPathExecute_SetOpAttrs(IntPtr op);

[DllImport(TensorFlowLibName)]
public static extern SafeStatusHandle TFE_QuickExecute(IntPtr ctx,
public static extern SafeStatusHandle TFE_QuickExecute(SafeContextHandle ctx,
string device_name,
string op_name,
IntPtr[] inputs,


+ 6
- 8
test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs View File

@@ -1,6 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Eager;
using Buffer = System.Buffer;

namespace TensorFlowNET.UnitTest
@@ -92,7 +93,7 @@ namespace TensorFlowNET.UnitTest
protected void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length)
=> c_api.TFE_OpSetAttrString(op, attr_name, value, length);

protected IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, SafeStatusHandle status)
protected IntPtr TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status)
=> c_api.TFE_NewOp(ctx, op_or_function_name, status);

protected IntPtr TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status)
@@ -104,10 +105,7 @@ namespace TensorFlowNET.UnitTest
protected IntPtr TFE_NewContextOptions()
=> c_api.TFE_NewContextOptions();

protected void TFE_DeleteContext(IntPtr t)
=> c_api.TFE_DeleteContext(t);

protected IntPtr TFE_NewContext(IntPtr opts, SafeStatusHandle status)
protected SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status)
=> c_api.TFE_NewContext(opts, status);

protected void TFE_DeleteContextOptions(IntPtr opts)
@@ -131,7 +129,7 @@ namespace TensorFlowNET.UnitTest
protected void TFE_DeleteExecutor(IntPtr executor)
=> c_api.TFE_DeleteExecutor(executor);

protected IntPtr TFE_ContextGetExecutorForThread(IntPtr ctx)
protected IntPtr TFE_ContextGetExecutorForThread(SafeContextHandle ctx)
=> c_api.TFE_ContextGetExecutorForThread(ctx);

protected void TFE_ExecutorWaitForAllPendingNodes(IntPtr executor, SafeStatusHandle status)
@@ -146,7 +144,7 @@ namespace TensorFlowNET.UnitTest
protected string TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status)
=> c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status));

protected IntPtr TFE_ContextListDevices(IntPtr ctx, SafeStatusHandle status)
protected IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status)
=> c_api.TFE_ContextListDevices(ctx, status);

protected int TF_DeviceListCount(IntPtr list)
@@ -161,7 +159,7 @@ namespace TensorFlowNET.UnitTest
protected void TF_DeleteDeviceList(IntPtr list)
=> c_api.TF_DeleteDeviceList(list);

protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, IntPtr ctx, string device_name, SafeStatusHandle status)
protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status)
=> c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status);

protected void TFE_OpSetDevice(IntPtr op, string device_name, SafeStatusHandle status)


+ 7
- 6
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs View File

@@ -1,7 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Eager;

namespace TensorFlowNET.UnitTest.NativeAPI
{
@@ -15,14 +14,16 @@ namespace TensorFlowNET.UnitTest.NativeAPI
{
using var status = c_api.TF_NewStatus();
var opts = c_api.TFE_NewContextOptions();
var ctx = c_api.TFE_NewContext(opts, status);

c_api.TFE_DeleteContextOptions(opts);
IntPtr devices;
using (var ctx = c_api.TFE_NewContext(opts, status))
{
c_api.TFE_DeleteContextOptions(opts);

var devices = c_api.TFE_ContextListDevices(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
devices = c_api.TFE_ContextListDevices(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
}

c_api.TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

int num_devices = c_api.TF_DeviceListCount(devices);


+ 20
- 16
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs View File

@@ -21,24 +21,28 @@ namespace TensorFlowNET.UnitTest.NativeAPI
using var status = TF_NewStatus();
var opts = TFE_NewContextOptions();
c_api.TFE_ContextOptionsSetAsync(opts, Convert.ToByte(async));
var ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteContextOptions(opts);

var m = TestMatrixTensorHandle();
var matmul = MatMulOp(ctx, m, m);
var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero };
int num_retvals = 2;
c_api.TFE_Execute(matmul, retvals, ref num_retvals, status);
EXPECT_EQ(1, num_retvals);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
IntPtr t;
using (var ctx = TFE_NewContext(opts, status))
{
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteContextOptions(opts);

var m = TestMatrixTensorHandle();
var matmul = MatMulOp(ctx, m, m);
var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero };
int num_retvals = 2;
c_api.TFE_Execute(matmul, retvals, ref num_retvals, status);
EXPECT_EQ(1, num_retvals);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);

t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteTensorHandle(retvals[0]);
}

var t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
var product = new float[4];
EXPECT_EQ(product.Length * sizeof(float), (int)TF_TensorByteSize(t));


+ 1
- 2
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs View File

@@ -16,7 +16,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
{
using var status = TF_NewStatus();
var opts = TFE_NewContextOptions();
var ctx = TFE_NewContext(opts, status);
using var ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteContextOptions(opts);

@@ -57,7 +57,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI
TFE_DeleteTensorHandle(input2);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(retvals[1]);
TFE_DeleteContext(ctx);
}
}
}

+ 1
- 2
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs View File

@@ -18,7 +18,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
{
using var status = TF_NewStatus();
var opts = TFE_NewContextOptions();
var ctx = TFE_NewContext(opts, status);
using var ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteContextOptions(opts);

@@ -50,7 +50,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI
TFE_DeleteTensorHandle(t1);
TFE_DeleteTensorHandle(t2);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteContext(ctx);
}
}
}

+ 1
- 2
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs View File

@@ -16,7 +16,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
{
var status = c_api.TF_NewStatus();
var opts = TFE_NewContextOptions();
var ctx = TFE_NewContext(opts, status);
using var ctx = TFE_NewContext(opts, status);
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

@@ -65,7 +65,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
}
}
}

+ 1
- 2
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs View File

@@ -15,7 +15,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
{
using var status = c_api.TF_NewStatus();
var opts = TFE_NewContextOptions();
var ctx = TFE_NewContext(opts, status);
using var ctx = TFE_NewContext(opts, status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteContextOptions(opts);

@@ -47,7 +47,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI

TFE_DeleteTensorHandle(var_handle);
TFE_DeleteTensorHandle(value_handle[0]);
TFE_DeleteContext(ctx);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
}
}


+ 5
- 4
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs View File

@@ -1,6 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.NativeAPI
@@ -25,7 +26,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
return th;
}

IntPtr MatMulOp(IntPtr ctx, IntPtr a, IntPtr b)
IntPtr MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b)
{
using var status = TF_NewStatus();

@@ -40,7 +41,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
return op;
}

bool GetDeviceName(IntPtr ctx, ref string device_name, string device_type)
bool GetDeviceName(SafeContextHandle ctx, ref string device_name, string device_type)
{
var status = TF_NewStatus();
var devices = TFE_ContextListDevices(ctx, status);
@@ -65,7 +66,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
return false;
}

IntPtr ShapeOp(IntPtr ctx, IntPtr a)
IntPtr ShapeOp(SafeContextHandle ctx, IntPtr a)
{
using var status = TF_NewStatus();

@@ -78,7 +79,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
return op;
}

unsafe IntPtr CreateVariable(IntPtr ctx, float value, SafeStatusHandle status)
unsafe IntPtr CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status)
{
var op = TFE_NewOp(ctx, "VarHandleOp", status);
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;


Loading…
Cancel
Save