Browse Source

Implement SafeTensorHandleHandle as a wrapper for TFE_TensorHandle

tags/v0.20
Sam Harwell Haiping 5 years ago
parent
commit
263fbcf244
19 changed files with 369 additions and 174 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Device/c_api.device.cs
  2. +3
    -3
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
  3. +3
    -3
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  4. +10
    -4
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  5. +2
    -1
      src/TensorFlowNET.Core/Eager/EagerTensor.Implicit.cs
  6. +40
    -0
      src/TensorFlowNET.Core/Eager/SafeTensorHandleHandle.cs
  7. +0
    -19
      src/TensorFlowNET.Core/Eager/TFE_TensorHandle.cs
  8. +26
    -11
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  9. +6
    -2
      src/TensorFlowNET.Core/Tensors/EagerTensorV2.cs
  10. +2
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  11. +132
    -0
      src/TensorFlowNET.Core/Util/SafeHandleArrayMarshaler.cs
  12. +11
    -11
      test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs
  13. +17
    -13
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs
  14. +20
    -16
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs
  15. +8
    -12
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs
  16. +1
    -2
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandle.cs
  17. +37
    -32
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs
  18. +32
    -25
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs
  19. +18
    -18
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs

+ 1
- 1
src/TensorFlowNET.Core/Device/c_api.device.cs View File

@@ -67,7 +67,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns>TFE_TensorHandle*</returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status);
public static extern SafeTensorHandleHandle TFE_TensorHandleCopyToDevice(SafeTensorHandleHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status);

/// <summary>
/// Retrieves the full name of the device (e.g. /job:worker/replica:0/...)


+ 3
- 3
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs View File

@@ -33,7 +33,7 @@ namespace Tensorflow.Eager
{
for (int i = 0; i < inputs.Length; ++i)
{
IntPtr tensor_handle;
SafeTensorHandleHandle tensor_handle;
switch (inputs[i])
{
case EagerTensor et:
@@ -50,10 +50,10 @@ namespace Tensorflow.Eager
if (status.ok() && attrs != null)
SetOpAttrs(op, attrs);

var outputs = new IntPtr[num_outputs];
var outputs = new SafeTensorHandleHandle[num_outputs];
if (status.ok())
{
c_api.TFE_Execute(op, outputs, ref num_outputs, status.Handle);
c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle);
status.Check(true);
}
return outputs.Select(x => new EagerTensor(x)).ToArray();


+ 3
- 3
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -154,8 +154,8 @@ namespace Tensorflow.Eager
num_retvals += (int)delta;
}

var retVals = new IntPtr[num_retvals];
c_api.TFE_Execute(op, retVals, ref num_retvals, status.Handle);
var retVals = new SafeTensorHandleHandle[num_retvals];
c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle);
status.Check(true);

var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray();
@@ -220,7 +220,7 @@ namespace Tensorflow.Eager
SafeOpHandle op,
Status status)
{
IntPtr input_handle;
SafeTensorHandleHandle input_handle;

// ConvertToTensor();
switch (inputs)


+ 10
- 4
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -14,7 +14,7 @@ namespace Tensorflow.Eager
}

public EagerTensor(IntPtr handle) : base(IntPtr.Zero)
public EagerTensor(SafeTensorHandleHandle handle) : base(IntPtr.Zero)
{
EagerTensorHandle = handle;
Resolve();
@@ -58,14 +58,20 @@ namespace Tensorflow.Eager
}

public override IntPtr ToPointer()
=> EagerTensorHandle;
=> EagerTensorHandle?.DangerousGetHandle() ?? IntPtr.Zero;

protected override void DisposeManagedResources()
{
base.DisposeManagedResources();

//print($"deleting DeleteTensorHandle {Id} {EagerTensorHandle.ToString("x16")}");
EagerTensorHandle.Dispose();
}

protected override void DisposeUnmanagedResources(IntPtr handle)
{
//print($"deleting DeleteTensorHandle {Id} {_handle.ToString("x16")}");
c_api.TF_DeleteTensor(_handle);
//print($"deleting DeleteTensorHandle {Id} {EagerTensorHandle.ToString("x16")}");
c_api.TFE_DeleteTensorHandle(EagerTensorHandle);
}
}
}

+ 2
- 1
src/TensorFlowNET.Core/Eager/EagerTensor.Implicit.cs View File

@@ -8,7 +8,8 @@ namespace Tensorflow.Eager
{
public partial class EagerTensor
{
[Obsolete("Implicit conversion of EagerTensor to IntPtr is not supported.", error: true)]
public static implicit operator IntPtr(EagerTensor tensor)
=> tensor.EagerTensorHandle;
=> throw new NotSupportedException();
}
}

+ 40
- 0
src/TensorFlowNET.Core/Eager/SafeTensorHandleHandle.cs View File

@@ -0,0 +1,40 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using Tensorflow.Util;

namespace Tensorflow.Eager
{
public sealed class SafeTensorHandleHandle : SafeTensorflowHandle
{
private SafeTensorHandleHandle()
{
}

public SafeTensorHandleHandle(IntPtr handle)
: base(handle)
{
}

protected override bool ReleaseHandle()
{
c_api.TFE_DeleteTensorHandle(handle);
SetHandle(IntPtr.Zero);
return true;
}
}
}

+ 0
- 19
src/TensorFlowNET.Core/Eager/TFE_TensorHandle.cs View File

@@ -1,19 +0,0 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow.Eager
{
[StructLayout(LayoutKind.Sequential)]
public struct TFE_TensorHandle
{
IntPtr _handle;

public static implicit operator IntPtr(TFE_TensorHandle tensor)
=> tensor._handle;

public override string ToString()
=> $"TFE_TensorHandle 0x{_handle.ToString("x16")}";
}
}

+ 26
- 11
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -3,6 +3,7 @@ using System;
using System.Runtime.InteropServices;
using Tensorflow.Device;
using Tensorflow.Eager;
using Tensorflow.Util;

namespace Tensorflow
{
@@ -66,7 +67,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TFE_OpAddInputList(SafeOpHandle op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status);
public static extern int TFE_OpAddInputList(SafeOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status);

/// <summary>
///
@@ -90,6 +91,20 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TFE_DeleteContext(IntPtr ctx);

public static void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status)
{
unsafe
{
num_retvals = retvals?.Length ?? 0;
var rawReturns = stackalloc IntPtr[num_retvals];
TFE_Execute(op, rawReturns, ref num_retvals, status);
for (var i = 0; i < num_retvals; i++)
{
retvals[i] = new SafeTensorHandleHandle(rawReturns[i]);
}
}
}

/// <summary>
/// Execute the operation defined by 'op' and return handles to computed
/// tensors in `retvals`.
@@ -99,7 +114,7 @@ namespace Tensorflow
/// <param name="num_retvals">int*</param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern void TFE_Execute(SafeOpHandle op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status);
private static unsafe extern void TFE_Execute(SafeOpHandle op, IntPtr* retvals, ref int num_retvals, SafeStatusHandle status);

/// <summary>
///
@@ -198,7 +213,7 @@ namespace Tensorflow
/// <param name="h">TFE_TensorHandle*</param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern void TFE_OpAddInput(SafeOpHandle op, IntPtr h, SafeStatusHandle status);
public static extern void TFE_OpAddInput(SafeOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status);

/// <summary>
///
@@ -206,10 +221,10 @@ namespace Tensorflow
/// <param name="t">const tensorflow::Tensor&amp;</param>
/// <returns>TFE_TensorHandle*</returns>
[DllImport(TensorFlowLibName)]
public static extern TFE_TensorHandle TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status);
public static extern SafeTensorHandleHandle TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status);

[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_EagerTensorHandle(IntPtr t);
public static extern SafeTensorHandleHandle TFE_EagerTensorHandle(IntPtr t);

/// <summary>
/// Sets the default execution mode (sync/async). Note that this can be
@@ -226,7 +241,7 @@ namespace Tensorflow
/// <param name="h">TFE_TensorHandle*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern TF_DataType TFE_TensorHandleDataType(IntPtr h);
public static extern TF_DataType TFE_TensorHandleDataType(SafeTensorHandleHandle h);

/// <summary>
/// This function will block till the operation that produces `h` has
@@ -237,7 +252,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_TensorHandleResolve(IntPtr h, SafeStatusHandle status);
public static extern IntPtr TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status);


/// <summary>
@@ -247,10 +262,10 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TFE_TensorHandleNumDims(IntPtr h, SafeStatusHandle status);
public static extern int TFE_TensorHandleNumDims(SafeTensorHandleHandle h, SafeStatusHandle status);

[DllImport(TensorFlowLibName)]
public static extern int TFE_TensorHandleDim(IntPtr h, int dim, SafeStatusHandle status);
public static extern int TFE_TensorHandleDim(SafeTensorHandleHandle h, int dim, SafeStatusHandle status);

/// <summary>
/// Returns the device of the operation that produced `h`. If `h` was produced by
@@ -263,7 +278,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_TensorHandleDeviceName(IntPtr h, SafeStatusHandle status);
public static extern IntPtr TFE_TensorHandleDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status);

/// <summary>
/// Returns the name of the device in whose memory `h` resides.
@@ -272,7 +287,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status);
public static extern IntPtr TFE_TensorHandleBackingDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status);

/// <summary>
///


+ 6
- 2
src/TensorFlowNET.Core/Tensors/EagerTensorV2.cs View File

@@ -12,7 +12,7 @@ namespace Tensorflow
{
public class EagerTensorV2 : DisposableObject, ITensor
{
IntPtr EagerTensorHandle;
SafeTensorHandleHandle EagerTensorHandle;
public string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, tf.status.Handle));

public EagerTensorV2(IntPtr handle)
@@ -64,10 +64,14 @@ namespace Tensorflow
}
}*/

protected override void DisposeManagedResources()
{
EagerTensorHandle.Dispose();
}

protected override void DisposeUnmanagedResources(IntPtr handle)
{
c_api.TF_DeleteTensor(_handle);
c_api.TFE_DeleteTensorHandle(EagerTensorHandle);
}
}
}

+ 2
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -23,6 +23,7 @@ using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using static Tensorflow.Binding;
using Tensorflow.Eager;
using Tensorflow.Framework;

namespace Tensorflow
@@ -94,7 +95,7 @@ namespace Tensorflow
/// <summary>
/// TFE_TensorHandle
/// </summary>
public IntPtr EagerTensorHandle { get; set; }
public SafeTensorHandleHandle EagerTensorHandle { get; set; }

/// <summary>
/// Returns the shape of a tensor.


+ 132
- 0
src/TensorFlowNET.Core/Util/SafeHandleArrayMarshaler.cs View File

@@ -0,0 +1,132 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;

namespace Tensorflow.Util
{
internal sealed class SafeHandleArrayMarshaler : ICustomMarshaler
{
private static readonly SafeHandleArrayMarshaler Instance = new SafeHandleArrayMarshaler();

private SafeHandleArrayMarshaler()
{
}

#pragma warning disable IDE0060 // Remove unused parameter (method is used implicitly)
public static ICustomMarshaler GetInstance(string cookie)
#pragma warning restore IDE0060 // Remove unused parameter
{
return Instance;
}

public int GetNativeDataSize()
{
return IntPtr.Size;
}

[HandleProcessCorruptedStateExceptions]
public IntPtr MarshalManagedToNative(object ManagedObj)
{
if (ManagedObj is null)
return IntPtr.Zero;

var array = (SafeHandle[])ManagedObj;
var native = IntPtr.Zero;
var marshaledArrayHandle = false;
try
{
native = Marshal.AllocHGlobal((array.Length + 1) * IntPtr.Size);
Marshal.WriteIntPtr(native, GCHandle.ToIntPtr(GCHandle.Alloc(array)));
marshaledArrayHandle = true;

var i = 0;
var success = false;
try
{
for (i = 0; i < array.Length; i++)
{
success = false;
var current = array[i];
var currentHandle = IntPtr.Zero;
if (current is object)
{
current.DangerousAddRef(ref success);
currentHandle = current.DangerousGetHandle();
}

Marshal.WriteIntPtr(native, ofs: (i + 1) * IntPtr.Size, currentHandle);
}

return IntPtr.Add(native, IntPtr.Size);
}
catch
{
// Clean up any handles which were leased prior to the exception
var total = success ? i + 1 : i;
for (var j = 0; j < total; j++)
{
var current = array[i];
if (current is object)
current.DangerousRelease();
}

throw;
}
}
catch
{
if (native != IntPtr.Zero)
{
if (marshaledArrayHandle)
GCHandle.FromIntPtr(Marshal.ReadIntPtr(native)).Free();

Marshal.FreeHGlobal(native);
}

throw;
}
}

public void CleanUpNativeData(IntPtr pNativeData)
{
if (pNativeData == IntPtr.Zero)
return;

var managedHandle = GCHandle.FromIntPtr(Marshal.ReadIntPtr(pNativeData, -IntPtr.Size));
var array = (SafeHandle[])managedHandle.Target;
managedHandle.Free();

for (var i = 0; i < array.Length; i++)
{
if (array[i] is object && !array[i].IsClosed)
array[i].DangerousRelease();
}
}

public object MarshalNativeToManaged(IntPtr pNativeData)
{
throw new NotSupportedException();
}

public void CleanUpManagedData(object ManagedObj)
{
throw new NotSupportedException();
}
}
}

+ 11
- 11
test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs View File

@@ -56,10 +56,10 @@ namespace TensorFlowNET.UnitTest
protected void TF_SetAttrBool(OperationDescription desc, string attrName, bool value)
=> c_api.TF_SetAttrBool(desc, attrName, value);

protected TF_DataType TFE_TensorHandleDataType(IntPtr h)
protected TF_DataType TFE_TensorHandleDataType(SafeTensorHandleHandle h)
=> c_api.TFE_TensorHandleDataType(h);

protected int TFE_TensorHandleNumDims(IntPtr h, SafeStatusHandle status)
protected int TFE_TensorHandleNumDims(SafeTensorHandleHandle h, SafeStatusHandle status)
=> c_api.TFE_TensorHandleNumDims(h, status);

protected TF_Code TF_GetCode(Status s)
@@ -83,7 +83,7 @@ namespace TensorFlowNET.UnitTest
protected ulong TF_TensorByteSize(IntPtr t)
=> c_api.TF_TensorByteSize(t);

protected void TFE_OpAddInput(SafeOpHandle op, IntPtr h, SafeStatusHandle status)
protected void TFE_OpAddInput(SafeOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status)
=> c_api.TFE_OpAddInput(op, h, status);

protected void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value)
@@ -98,11 +98,11 @@ namespace TensorFlowNET.UnitTest
protected SafeOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status)
=> c_api.TFE_NewOp(ctx, op_or_function_name, status);

protected IntPtr TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status)
protected SafeTensorHandleHandle TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status)
=> c_api.TFE_NewTensorHandle(t, status);

protected void TFE_Execute(SafeOpHandle op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status)
=> c_api.TFE_Execute(op, retvals, ref num_retvals, status);
protected void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status)
=> c_api.TFE_Execute(op, retvals, out num_retvals, status);

protected SafeContextOptionsHandle TFE_NewContextOptions()
=> c_api.TFE_NewContextOptions();
@@ -113,7 +113,7 @@ namespace TensorFlowNET.UnitTest
protected int TFE_OpGetInputLength(SafeOpHandle op, string input_name, SafeStatusHandle status)
=> c_api.TFE_OpGetInputLength(op, input_name, status);

protected int TFE_OpAddInputList(SafeOpHandle op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status)
protected int TFE_OpAddInputList(SafeOpHandle op, SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status)
=> c_api.TFE_OpAddInputList(op, inputs, num_inputs, status);

protected int TFE_OpGetOutputLength(SafeOpHandle op, string input_name, SafeStatusHandle status)
@@ -128,13 +128,13 @@ namespace TensorFlowNET.UnitTest
protected void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status)
=> c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status);

protected IntPtr TFE_TensorHandleResolve(IntPtr h, SafeStatusHandle status)
protected IntPtr TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status)
=> c_api.TFE_TensorHandleResolve(h, status);

protected string TFE_TensorHandleDeviceName(IntPtr h, SafeStatusHandle status)
protected string TFE_TensorHandleDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status)
=> c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(h, status));

protected string TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status)
protected string TFE_TensorHandleBackingDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status)
=> c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status));

protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status)
@@ -149,7 +149,7 @@ namespace TensorFlowNET.UnitTest
protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status)
=> c_api.TF_DeviceListName(list, index, status);

protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status)
protected SafeTensorHandleHandle TFE_TensorHandleCopyToDevice(SafeTensorHandleHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status)
=> c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status);

protected void TFE_OpSetDevice(SafeOpHandle op, string device_name, SafeStatusHandle status)


+ 17
- 13
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs View File

@@ -33,21 +33,25 @@ namespace TensorFlowNET.UnitTest.NativeAPI
{
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

var m = TestMatrixTensorHandle();
var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero };
using (var matmul = MatMulOp(ctx, m, m))
var retvals = new SafeTensorHandleHandle[2];
try
{
int num_retvals = 2;
c_api.TFE_Execute(matmul, retvals, ref num_retvals, status);
EXPECT_EQ(1, num_retvals);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
}

TFE_DeleteTensorHandle(m);
using (var m = TestMatrixTensorHandle())
using (var matmul = MatMulOp(ctx, m, m))
{
int num_retvals;
c_api.TFE_Execute(matmul, retvals, out num_retvals, status);
EXPECT_EQ(1, num_retvals);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
}

t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteTensorHandle(retvals[0]);
t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
}
finally
{
retvals[0]?.Dispose();
}
}

ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));


+ 20
- 16
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs View File

@@ -24,9 +24,10 @@ namespace TensorFlowNET.UnitTest.NativeAPI
using var ctx = NewContext(status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

var input1 = TestMatrixTensorHandle();
var input2 = TestMatrixTensorHandle();
var retvals = new IntPtr[2];
using var input1 = TestMatrixTensorHandle();
using var input2 = TestMatrixTensorHandle();

var retvals = new SafeTensorHandleHandle[2];
using (var identityOp = TFE_NewOp(ctx, "IdentityN", status))
{
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
@@ -37,7 +38,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status));
CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status));

var inputs = new IntPtr[] { input1, input2 };
var inputs = new SafeTensorHandleHandle[] { input1, input2 };
TFE_OpAddInputList(identityOp, inputs, 2, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

@@ -47,21 +48,24 @@ namespace TensorFlowNET.UnitTest.NativeAPI
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

int num_retvals = 2;
TFE_Execute(identityOp, retvals, ref num_retvals, status);
int num_retvals;
TFE_Execute(identityOp, retvals, out num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

// Try to retrieve lengths after executing the op (should work)
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
try
{
// Try to retrieve lengths after executing the op (should work)
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
}
finally
{
retvals[0]?.Dispose();
retvals[1]?.Dispose();
}
}

TFE_DeleteTensorHandle(input1);
TFE_DeleteTensorHandle(input2);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(retvals[1]);
}
}
}

+ 8
- 12
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs View File

@@ -1,5 +1,4 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Eager;

@@ -24,10 +23,9 @@ namespace TensorFlowNET.UnitTest.NativeAPI
using var ctx = NewContext(status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

var condition = TestScalarTensorHandle(true);
var t1 = TestMatrixTensorHandle();
var t2 = TestAxisTensorHandle();
var retvals = new IntPtr[1];
using var condition = TestScalarTensorHandle(true);
using var t1 = TestMatrixTensorHandle();
using var t2 = TestAxisTensorHandle();
using (var assertOp = TFE_NewOp(ctx, "Assert", status))
{
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
@@ -44,15 +42,13 @@ namespace TensorFlowNET.UnitTest.NativeAPI
//EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT);
//EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32);

int num_retvals = 1;
TFE_Execute(assertOp, retvals, ref num_retvals, status);
var retvals = new SafeTensorHandleHandle[1];
int num_retvals;
TFE_Execute(assertOp, retvals, out num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
}

TFE_DeleteTensorHandle(condition);
TFE_DeleteTensorHandle(t1);
TFE_DeleteTensorHandle(t2);
TFE_DeleteTensorHandle(retvals[0]);
retvals[0]?.Dispose();
}
}
}
}

+ 1
- 2
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandle.cs View File

@@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
[TestMethod]
public unsafe void TensorHandle()
{
var h = TestMatrixTensorHandle();
using var h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, c_api.TFE_TensorHandleDataType(h));

var status = c_api.TF_NewStatus();
@@ -28,7 +28,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI
EXPECT_EQ(3.0f, data[2]);
EXPECT_EQ(4.0f, data[3]);
c_api.TF_DeleteTensor(t);
c_api.TFE_DeleteTensorHandle(h);
}
}
}

+ 37
- 32
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs View File

@@ -24,47 +24,52 @@ namespace TensorFlowNET.UnitTest.NativeAPI
using var ctx = NewContext(status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

var hcpu = TestMatrixTensorHandle();
var device_name = TFE_TensorHandleDeviceName(hcpu, status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
ASSERT_TRUE(device_name.Contains("CPU:0"));

var backing_device_name = TFE_TensorHandleBackingDeviceName(hcpu, status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
ASSERT_TRUE(backing_device_name.Contains("CPU:0"));

// Disable the test if no GPU is present.
string gpu_device_name = "";
if(GetDeviceName(ctx, ref gpu_device_name, "GPU"))
using (var hcpu = TestMatrixTensorHandle())
{
var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));
var device_name = TFE_TensorHandleDeviceName(hcpu, status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
ASSERT_TRUE(device_name.Contains("CPU:0"));

var retvals = new IntPtr[1];
using (var shape_op = ShapeOp(ctx, hgpu))
var backing_device_name = TFE_TensorHandleBackingDeviceName(hcpu, status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
ASSERT_TRUE(backing_device_name.Contains("CPU:0"));

// Disable the test if no GPU is present.
string gpu_device_name = "";
if (GetDeviceName(ctx, ref gpu_device_name, "GPU"))
{
TFE_OpSetDevice(shape_op, gpu_device_name, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));
int num_retvals = 1;
c_api.TFE_Execute(shape_op, retvals, ref num_retvals, status);
using var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));

// .device of shape is GPU since the op is executed on GPU
device_name = TFE_TensorHandleDeviceName(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
ASSERT_TRUE(device_name.Contains("GPU:0"));
var retvals = new SafeTensorHandleHandle[1];
using (var shape_op = ShapeOp(ctx, hgpu))
{
TFE_OpSetDevice(shape_op, gpu_device_name, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));
int num_retvals;
c_api.TFE_Execute(shape_op, retvals, out num_retvals, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));

// .backing_device of shape is CPU since the tensor is backed by CPU
backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
ASSERT_TRUE(backing_device_name.Contains("CPU:0"));
}
try
{
// .device of shape is GPU since the op is executed on GPU
device_name = TFE_TensorHandleDeviceName(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
ASSERT_TRUE(device_name.Contains("GPU:0"));

TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);
// .backing_device of shape is CPU since the tensor is backed by CPU
backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
ASSERT_TRUE(backing_device_name.Contains("CPU:0"));
}
finally
{
retvals[0]?.Dispose();
}
}
}
}

TFE_DeleteTensorHandle(hcpu);
// not export api
using var executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);


+ 32
- 25
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs View File

@@ -25,35 +25,42 @@ namespace TensorFlowNET.UnitTest.NativeAPI
using var ctx = NewContext(status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

var var_handle = CreateVariable(ctx, 12.0f, status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

int num_retvals = 1;
var value_handle = new[] { IntPtr.Zero };
using (var op = TFE_NewOp(ctx, "ReadVariableOp", status))
using (var var_handle = CreateVariable(ctx, 12.0f, status))
{
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var_handle, status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_Execute(op, value_handle, ref num_retvals, status);
}

ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
ASSERT_EQ(1, num_retvals);
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle[0]));
EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle[0], status));
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
var value = 0f; // new float[1];
var t = TFE_TensorHandleResolve(value_handle[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
ASSERT_EQ(sizeof(float), (int)TF_TensorByteSize(t));
tf.memcpy(&value, TF_TensorData(t).ToPointer(), sizeof(float));
c_api.TF_DeleteTensor(t);
EXPECT_EQ(12.0f, value);
int num_retvals = 1;
var value_handle = new SafeTensorHandleHandle[1];
using (var op = TFE_NewOp(ctx, "ReadVariableOp", status))
{
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var_handle, status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_Execute(op, value_handle, out num_retvals, status);
}

try
{
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
ASSERT_EQ(1, num_retvals);
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle[0]));
EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle[0], status));
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
var value = 0f; // new float[1];
var t = TFE_TensorHandleResolve(value_handle[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
ASSERT_EQ(sizeof(float), (int)TF_TensorByteSize(t));
tf.memcpy(&value, TF_TensorData(t).ToPointer(), sizeof(float));
c_api.TF_DeleteTensor(t);
EXPECT_EQ(12.0f, value);
}
finally
{
value_handle[0]?.Dispose();
}
}

TFE_DeleteTensorHandle(var_handle);
TFE_DeleteTensorHandle(value_handle[0]);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
}
}


+ 18
- 18
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs View File

@@ -12,7 +12,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
[TestClass]
public partial class CApiEagerTest : CApiTest
{
IntPtr TestMatrixTensorHandle()
SafeTensorHandleHandle TestMatrixTensorHandle()
{
var dims = new long[] { 2, 2 };
var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
@@ -26,7 +26,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
return th;
}

SafeOpHandle MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b)
SafeOpHandle MatMulOp(SafeContextHandle ctx, SafeTensorHandleHandle a, SafeTensorHandleHandle b)
{
using var status = TF_NewStatus();

@@ -64,7 +64,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
return false;
}

SafeOpHandle ShapeOp(SafeContextHandle ctx, IntPtr a)
SafeOpHandle ShapeOp(SafeContextHandle ctx, SafeTensorHandleHandle a)
{
using var status = TF_NewStatus();

@@ -77,28 +77,28 @@ namespace TensorFlowNET.UnitTest.NativeAPI
return op;
}

unsafe IntPtr CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status)
unsafe SafeTensorHandleHandle CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status)
{
var var_handle = new IntPtr[1];
int num_retvals = 1;
var var_handle = new SafeTensorHandleHandle[1];
int num_retvals;
using (var op = TFE_NewOp(ctx, "VarHandleOp", status))
{
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op, "shape", new long[0], 0, status);
TFE_OpSetAttrString(op, "container", "", 0);
TFE_OpSetAttrString(op, "shared_name", "", 0);
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
TFE_Execute(op, var_handle, ref num_retvals, status);
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
TFE_Execute(op, var_handle, out num_retvals, status);
}

if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
CHECK_EQ(1, num_retvals);

// Assign 'value' to it.
using (var op = TFE_NewOp(ctx, "AssignVariableOp", status))
{
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var_handle[0], status);

@@ -107,22 +107,22 @@ namespace TensorFlowNET.UnitTest.NativeAPI
tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t));

var value_handle = c_api.TFE_NewTensorHandle(t, status);
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);

TFE_OpAddInput(op, value_handle, status);
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);

num_retvals = 0;
c_api.TFE_Execute(op, null, ref num_retvals, status);
c_api.TFE_Execute(op, null, out num_retvals, status);
}

if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
CHECK_EQ(0, num_retvals);

return var_handle[0];
}

IntPtr TestAxisTensorHandle()
SafeTensorHandleHandle TestAxisTensorHandle()
{
var dims = new long[] { 1 };
var data = new int[] { 1 };
@@ -135,7 +135,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
return th;
}

IntPtr TestScalarTensorHandle(bool value)
SafeTensorHandleHandle TestScalarTensorHandle(bool value)
{
var data = new[] { value };
var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool));
@@ -147,7 +147,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
return th;
}

IntPtr TestScalarTensorHandle(float value)
SafeTensorHandleHandle TestScalarTensorHandle(float value)
{
var data = new [] { value };
var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float));


Loading…
Cancel
Save