
Implement SafeOpHandle as a wrapper for TFE_Op

tags/v0.20
Sam Harwell · Haiping · 5 years ago
parent
commit fae2c29599
11 changed files with 177 additions and 153 deletions
  1. src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs (+7 -7)
  2. src/TensorFlowNET.Core/Eager/SafeOpHandle.cs (+40 -0)
  3. src/TensorFlowNET.Core/Eager/TFE_Op.cs (+0 -23)
  4. src/TensorFlowNET.Core/Eager/c_api.eager.cs (+14 -14)
  5. test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs (+10 -13)
  6. test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs (+8 -6)
  7. test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs (+26 -25)
  8. test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs (+19 -18)
  9. test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs (+16 -15)
  10. test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs (+8 -7)
  11. test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs (+29 -25)
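
In short, the change replaces the IntPtr-backed TFE_Op struct with SafeOpHandle, a SafeHandle-derived wrapper, so each native TFE_Op* is released exactly once through ReleaseHandle instead of through manual TFE_DeleteOp calls scattered across the callers. A minimal sketch of the call pattern this enables (the setup of ctx, status and the tensor handles a and b is assumed and not part of this diff):

    // Sketch only: ctx, status, a and b come from the surrounding test fixture.
    static IntPtr ExecuteMatMul(SafeContextHandle ctx, IntPtr a, IntPtr b, SafeStatusHandle status)
    {
        var retvals = new IntPtr[] { IntPtr.Zero };
        int num_retvals = 1;
        using (SafeOpHandle op = c_api.TFE_NewOp(ctx, "MatMul", status))
        {
            c_api.TFE_OpAddInput(op, a, status);
            c_api.TFE_OpAddInput(op, b, status);
            c_api.TFE_Execute(op, retvals, ref num_retvals, status);
        }   // Dispose() -> ReleaseHandle() -> c_api.TFE_DeleteOp, even if an assertion or exception escapes
        return retvals[0];   // TFE_TensorHandle* of the result
    }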

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs (+7 -7)

@@ -145,7 +145,7 @@ namespace Tensorflow.Eager
            return flat_result;
        }

-        TFE_Op GetOp(Context ctx, string op_or_function_name, Status status)
+        SafeOpHandle GetOp(Context ctx, string op_or_function_name, Status status)
        {
            if (thread_local_eager_operation_map.find(ctx, out var op))
                c_api.TFE_OpReset(op, op_or_function_name, ctx.device_name, status.Handle);

@@ -159,7 +159,7 @@ namespace Tensorflow.Eager
            return op;
        }

-        static UnorderedMap<Context, TFE_Op> thread_local_eager_operation_map = new UnorderedMap<Context, TFE_Op>();
+        static UnorderedMap<Context, SafeOpHandle> thread_local_eager_operation_map = new UnorderedMap<Context, SafeOpHandle>();

        bool HasAccumulator()
        {

@@ -192,7 +192,7 @@ namespace Tensorflow.Eager
            ArgDef input_arg,
            List<object> flattened_attrs,
            List<Tensor> flattened_inputs,
-            IntPtr op,
+            SafeOpHandle op,
            Status status)
        {
            IntPtr input_handle;

@@ -224,7 +224,7 @@ namespace Tensorflow.Eager
            return true;
        }

-        public void SetOpAttrs(TFE_Op op, params object[] attrs)
+        public void SetOpAttrs(SafeOpHandle op, params object[] attrs)
        {
            var status = tf.status;
            var len = attrs.Length;

@@ -257,7 +257,7 @@ namespace Tensorflow.Eager
        /// <param name="attr_value"></param>
        /// <param name="attr_list_sizes"></param>
        /// <param name="status"></param>
-        void SetOpAttrWithDefaults(Context ctx, IntPtr op, AttrDef attr,
+        void SetOpAttrWithDefaults(Context ctx, SafeOpHandle op, AttrDef attr,
            string attr_name, object attr_value,
            Dictionary<string, long> attr_list_sizes,
            Status status)

@@ -290,7 +290,7 @@ namespace Tensorflow.Eager
            }
        }

-        bool SetOpAttrList(Context ctx, IntPtr op,
+        bool SetOpAttrList(Context ctx, SafeOpHandle op,
            string key, object value, TF_AttrType type,
            Dictionary<string, long> attr_list_sizes,
            Status status)

@@ -298,7 +298,7 @@ namespace Tensorflow.Eager
            return false;
        }

-        bool SetOpAttrScalar(Context ctx, IntPtr op,
+        bool SetOpAttrScalar(Context ctx, SafeOpHandle op,
            string key, object value, TF_AttrType type,
            Dictionary<string, long> attr_list_sizes,
            Status status)


src/TensorFlowNET.Core/Eager/SafeOpHandle.cs (+40 -0)

@@ -0,0 +1,40 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using Tensorflow.Util;

namespace Tensorflow.Eager
{
public sealed class SafeOpHandle : SafeTensorflowHandle
{
private SafeOpHandle()
{
}

public SafeOpHandle(IntPtr handle)
: base(handle)
{
}

protected override bool ReleaseHandle()
{
c_api.TFE_DeleteOp(handle);
SetHandle(IntPtr.Zero);
return true;
}
}
}
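
SafeOpHandle derives from SafeTensorflowHandle in Tensorflow.Util, which is not part of this diff. As a rough sketch of what such a base class usually looks like (the details here are assumptions; the real implementation may differ):

    // Assumed shape of the base class, for illustration only.
    using System;
    using System.Runtime.InteropServices;

    public abstract class SafeTensorflowHandle : SafeHandle
    {
        protected SafeTensorflowHandle()
            : base(IntPtr.Zero, ownsHandle: true)   // starts invalid; filled in by the P/Invoke marshaler
        {
        }

        protected SafeTensorflowHandle(IntPtr handle)
            : base(IntPtr.Zero, ownsHandle: true)
        {
            SetHandle(handle);                      // wrap an already-created native pointer
        }

        public override bool IsInvalid => handle == IntPtr.Zero;

        // ReleaseHandle() is left abstract; SafeOpHandle above supplies it via TFE_DeleteOp.
    }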

src/TensorFlowNET.Core/Eager/TFE_Op.cs (+0 -23)

@@ -1,23 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Eager
{
public struct TFE_Op
{
IntPtr _handle;

public TFE_Op(IntPtr handle)
=> _handle = handle;

public static implicit operator TFE_Op(IntPtr handle)
=> new TFE_Op(handle);

public static implicit operator IntPtr(TFE_Op tensor)
=> tensor._handle;

public override string ToString()
=> $"TFE_Op 0x{_handle.ToString("x16")}";
}
}

src/TensorFlowNET.Core/Eager/c_api.eager.cs (+14 -14)

@@ -30,7 +30,7 @@ namespace Tensorflow
        /// <param name="status">TF_Status*</param>
        /// <returns></returns>
        [DllImport(TensorFlowLibName)]
-        public static extern TF_AttrType TFE_OpGetAttrType(IntPtr op, string attr_name, ref byte is_list, SafeStatusHandle status);
+        public static extern TF_AttrType TFE_OpGetAttrType(SafeOpHandle op, string attr_name, ref byte is_list, SafeStatusHandle status);

        [DllImport(TensorFlowLibName)]
        public static extern TF_AttrType TFE_OpNameGetAttrType(SafeContextHandle ctx, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status);

@@ -43,7 +43,7 @@ namespace Tensorflow
        /// <param name="input_name">const char*</param>
        /// <param name="status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
-        public static extern int TFE_OpGetInputLength(IntPtr op, string input_name, SafeStatusHandle status);
+        public static extern int TFE_OpGetInputLength(SafeOpHandle op, string input_name, SafeStatusHandle status);

        /// <summary>
        /// Returns the length (number of tensors) of the output argument `output_name`

@@ -54,7 +54,7 @@ namespace Tensorflow
        /// <param name="status"></param>
        /// <returns></returns>
        [DllImport(TensorFlowLibName)]
-        public static extern int TFE_OpGetOutputLength(IntPtr op, string input_name, SafeStatusHandle status);
+        public static extern int TFE_OpGetOutputLength(SafeOpHandle op, string input_name, SafeStatusHandle status);

        /// <summary>
        ///

@@ -65,7 +65,7 @@ namespace Tensorflow
        /// <param name="status">TF_Status*</param>
        /// <returns></returns>
        [DllImport(TensorFlowLibName)]
-        public static extern int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status);
+        public static extern int TFE_OpAddInputList(SafeOpHandle op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status);

        /// <summary>
        ///

@@ -98,7 +98,7 @@ namespace Tensorflow
        /// <param name="num_retvals">int*</param>
        /// <param name="status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
-        public static extern void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status);
+        public static extern void TFE_Execute(SafeOpHandle op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status);

        /// <summary>
        ///

@@ -108,7 +108,7 @@ namespace Tensorflow
        /// <param name="status">TF_Status*</param>
        /// <returns></returns>
        [DllImport(TensorFlowLibName)]
-        public static extern TFE_Op TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status);
+        public static extern SafeOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status);

        /// <summary>
        /// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This

@@ -124,7 +124,7 @@ namespace Tensorflow
        /// <param name="raw_device_name">const char*</param>
        /// <param name="status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
-        public static extern void TFE_OpReset(IntPtr op_to_reset, string op_or_function_name, string raw_device_name, SafeStatusHandle status);
+        public static extern void TFE_OpReset(SafeOpHandle op_to_reset, string op_or_function_name, string raw_device_name, SafeStatusHandle status);

        /// <summary>
        ///

@@ -140,10 +140,10 @@ namespace Tensorflow
        /// <param name="attr_name">const char*</param>
        /// <param name="value">TF_DataType</param>
        [DllImport(TensorFlowLibName)]
-        public static extern void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value);
+        public static extern void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value);

        [DllImport(TensorFlowLibName)]
-        public static extern void TFE_OpSetAttrInt(IntPtr op, string attr_name, long value);
+        public static extern void TFE_OpSetAttrInt(SafeOpHandle op, string attr_name, long value);

        /// <summary>
        ///

@@ -154,10 +154,10 @@ namespace Tensorflow
        /// <param name="num_dims">const int</param>
        /// <param name="out_status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
-        public static extern void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status);
+        public static extern void TFE_OpSetAttrShape(SafeOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status);

        [DllImport(TensorFlowLibName)]
-        public static extern void TFE_OpSetAttrBool(IntPtr op, string attr_name, bool value);
+        public static extern void TFE_OpSetAttrBool(SafeOpHandle op, string attr_name, bool value);

        /// <summary>
        ///

@@ -167,7 +167,7 @@ namespace Tensorflow
        /// <param name="value">const void*</param>
        /// <param name="length">size_t</param>
        [DllImport(TensorFlowLibName)]
-        public static extern void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length);
+        public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length);

        /// <summary>
        ///

@@ -176,7 +176,7 @@ namespace Tensorflow
        /// <param name="device_name"></param>
        /// <param name="status"></param>
        [DllImport(TensorFlowLibName)]
-        public static extern void TFE_OpSetDevice(TFE_Op op, string device_name, SafeStatusHandle status);
+        public static extern void TFE_OpSetDevice(SafeOpHandle op, string device_name, SafeStatusHandle status);

        /// <summary>
        ///

@@ -185,7 +185,7 @@ namespace Tensorflow
        /// <param name="h">TFE_TensorHandle*</param>
        /// <param name="status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
-        public static extern void TFE_OpAddInput(IntPtr op, IntPtr h, SafeStatusHandle status);
+        public static extern void TFE_OpAddInput(SafeOpHandle op, IntPtr h, SafeStatusHandle status);

        /// <summary>
        ///
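
These signatures lean on the standard .NET SafeHandle marshaling rules: when a SafeHandle-derived type appears as a parameter, the marshaler passes the underlying native pointer and ref-counts the handle for the duration of the call; when it appears as a return type, the marshaler creates the wrapper through its parameterless constructor, which is why SafeOpHandle keeps a private one. TFE_DeleteOp itself keeps an IntPtr parameter so that ReleaseHandle can invoke it on the raw handle. A generic sketch of the pattern against a hypothetical native library (none of these names belong to TensorFlow):

    using System;
    using System.Runtime.InteropServices;

    sealed class SafeWidgetHandle : SafeHandle
    {
        private SafeWidgetHandle() : base(IntPtr.Zero, true) { }   // used by the marshaler for return values

        public override bool IsInvalid => handle == IntPtr.Zero;

        protected override bool ReleaseHandle()
        {
            NativeMethods.widget_destroy(handle);                  // free the native object exactly once
            return true;
        }
    }

    static class NativeMethods
    {
        [DllImport("example")]
        public static extern SafeWidgetHandle widget_create();     // return value wrapped automatically

        [DllImport("example")]
        public static extern void widget_poke(SafeWidgetHandle w); // raw pointer passed; handle kept alive during the call

        [DllImport("example")]
        public static extern void widget_destroy(IntPtr w);        // IntPtr so ReleaseHandle can call it
    }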


test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs (+10 -13)

@@ -82,25 +82,25 @@ namespace TensorFlowNET.UnitTest
        protected ulong TF_TensorByteSize(IntPtr t)
            => c_api.TF_TensorByteSize(t);

-        protected void TFE_OpAddInput(IntPtr op, IntPtr h, SafeStatusHandle status)
+        protected void TFE_OpAddInput(SafeOpHandle op, IntPtr h, SafeStatusHandle status)
            => c_api.TFE_OpAddInput(op, h, status);

-        protected void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value)
+        protected void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value)
            => c_api.TFE_OpSetAttrType(op, attr_name, value);

-        protected void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status)
+        protected void TFE_OpSetAttrShape(SafeOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status)
            => c_api.TFE_OpSetAttrShape(op, attr_name, dims, num_dims, out_status);

-        protected void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length)
+        protected void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length)
            => c_api.TFE_OpSetAttrString(op, attr_name, value, length);

-        protected IntPtr TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status)
+        protected SafeOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status)
            => c_api.TFE_NewOp(ctx, op_or_function_name, status);

        protected IntPtr TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status)
            => c_api.TFE_NewTensorHandle(t, status);

-        protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status)
+        protected void TFE_Execute(SafeOpHandle op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status)
            => c_api.TFE_Execute(op, retvals, ref num_retvals, status);

        protected SafeContextOptionsHandle TFE_NewContextOptions()

@@ -109,21 +109,18 @@ namespace TensorFlowNET.UnitTest
        protected SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status)
            => c_api.TFE_NewContext(opts, status);

-        protected int TFE_OpGetInputLength(IntPtr op, string input_name, SafeStatusHandle status)
+        protected int TFE_OpGetInputLength(SafeOpHandle op, string input_name, SafeStatusHandle status)
            => c_api.TFE_OpGetInputLength(op, input_name, status);

-        protected int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status)
+        protected int TFE_OpAddInputList(SafeOpHandle op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status)
            => c_api.TFE_OpAddInputList(op, inputs, num_inputs, status);

-        protected int TFE_OpGetOutputLength(IntPtr op, string input_name, SafeStatusHandle status)
+        protected int TFE_OpGetOutputLength(SafeOpHandle op, string input_name, SafeStatusHandle status)
            => c_api.TFE_OpGetOutputLength(op, input_name, status);

        protected void TFE_DeleteTensorHandle(IntPtr h)
            => c_api.TFE_DeleteTensorHandle(h);

-        protected void TFE_DeleteOp(IntPtr op)
-            => c_api.TFE_DeleteOp(op);
-
        protected SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx)
            => c_api.TFE_ContextGetExecutorForThread(ctx);

@@ -154,7 +151,7 @@ namespace TensorFlowNET.UnitTest
        protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status)
            => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status);

-        protected void TFE_OpSetDevice(IntPtr op, string device_name, SafeStatusHandle status)
+        protected void TFE_OpSetDevice(SafeOpHandle op, string device_name, SafeStatusHandle status)
            => c_api.TFE_OpSetDevice(op, device_name, status);
    }
}

test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs (+8 -6)

@@ -34,13 +34,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI
            CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

            var m = TestMatrixTensorHandle();
-            var matmul = MatMulOp(ctx, m, m);
            var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero };
-            int num_retvals = 2;
-            c_api.TFE_Execute(matmul, retvals, ref num_retvals, status);
-            EXPECT_EQ(1, num_retvals);
-            EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
-            TFE_DeleteOp(matmul);
+            using (var matmul = MatMulOp(ctx, m, m))
+            {
+                int num_retvals = 2;
+                c_api.TFE_Execute(matmul, retvals, ref num_retvals, status);
+                EXPECT_EQ(1, num_retvals);
+                EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+            }
+
            TFE_DeleteTensorHandle(m);

            t = TFE_TensorHandleResolve(retvals[0], status);
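
The test files below already use C# 8 using declarations (`using var status = TF_NewStatus();` in CApi.Eager.cs), so the block above could also be written without braces; a sketch of that shorter form, not what the commit actually does, with disposal deferred to the end of the enclosing method:

    using var matmul = MatMulOp(ctx, m, m);
    int num_retvals = 2;
    c_api.TFE_Execute(matmul, retvals, ref num_retvals, status);
    EXPECT_EQ(1, num_retvals);
    EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
    // matmul is disposed when the test method exits rather than at a closing brace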


test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs (+26 -25)

@@ -26,37 +26,38 @@ namespace TensorFlowNET.UnitTest.NativeAPI

            var input1 = TestMatrixTensorHandle();
            var input2 = TestMatrixTensorHandle();
-            var identityOp = TFE_NewOp(ctx, "IdentityN", status);
-            CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+            var retvals = new IntPtr[2];
+            using (var identityOp = TFE_NewOp(ctx, "IdentityN", status))
+            {
+                CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

-            // Try to retrieve lengths before building the attributes (should fail)
-            EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status));
-            CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status));
-            EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status));
-            CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status));
+                // Try to retrieve lengths before building the attributes (should fail)
+                EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status));
+                CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status));
+                EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status));
+                CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status));

-            var inputs = new IntPtr[] { input1, input2 };
-            TFE_OpAddInputList(identityOp, inputs, 2, status);
-            CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+                var inputs = new IntPtr[] { input1, input2 };
+                TFE_OpAddInputList(identityOp, inputs, 2, status);
+                CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

-            // Try to retrieve lengths before executing the op (should work)
-            EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
-            CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
-            EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
-            CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+                // Try to retrieve lengths before executing the op (should work)
+                EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
+                CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+                EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
+                CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

-            var retvals = new IntPtr[2];
-            int num_retvals = 2;
-            TFE_Execute(identityOp, retvals, ref num_retvals, status);
-            EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+                int num_retvals = 2;
+                TFE_Execute(identityOp, retvals, ref num_retvals, status);
+                EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

-            // Try to retrieve lengths after executing the op (should work)
-            EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
-            CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
-            EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
-            CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+                // Try to retrieve lengths after executing the op (should work)
+                EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
+                CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+                EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
+                CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+            }

-            TFE_DeleteOp(identityOp);
            TFE_DeleteTensorHandle(input1);
            TFE_DeleteTensorHandle(input2);
            TFE_DeleteTensorHandle(retvals[0]);


test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs (+19 -18)

@@ -27,27 +27,28 @@ namespace TensorFlowNET.UnitTest.NativeAPI
            var condition = TestScalarTensorHandle(true);
            var t1 = TestMatrixTensorHandle();
            var t2 = TestAxisTensorHandle();
-            var assertOp = TFE_NewOp(ctx, "Assert", status);
-            CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
-            TFE_OpAddInput(assertOp, condition, status);
-            CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
-            var data = new[] { condition, t1, t2 };
-            TFE_OpAddInputList(assertOp, data, 3, status);
-            CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+            var retvals = new IntPtr[1];
+            using (var assertOp = TFE_NewOp(ctx, "Assert", status))
+            {
+                CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+                TFE_OpAddInput(assertOp, condition, status);
+                CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+                var data = new[] { condition, t1, t2 };
+                TFE_OpAddInputList(assertOp, data, 3, status);
+                CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

-            /*var attr_values = Graph.TFE_GetOpDef("Assert").Attr;
-            var attr_found = attr_values.First(x => x.Name == "T");
-            EXPECT_NE(attr_found, attr_values.Last());*/
-            // EXPECT_EQ(attr_found.Type[0], "DT_BOOL");
-            //EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT);
-            //EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32);
+                /*var attr_values = Graph.TFE_GetOpDef("Assert").Attr;
+                var attr_found = attr_values.First(x => x.Name == "T");
+                EXPECT_NE(attr_found, attr_values.Last());*/
+                // EXPECT_EQ(attr_found.Type[0], "DT_BOOL");
+                //EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT);
+                //EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32);

-            var retvals = new IntPtr[1];
-            int num_retvals = 1;
-            TFE_Execute(assertOp, retvals, ref num_retvals, status);
-            EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+                int num_retvals = 1;
+                TFE_Execute(assertOp, retvals, ref num_retvals, status);
+                EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+            }

-            TFE_DeleteOp(assertOp);
            TFE_DeleteTensorHandle(condition);
            TFE_DeleteTensorHandle(t1);
            TFE_DeleteTensorHandle(t2);


test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs (+16 -15)

@@ -40,25 +40,26 @@ namespace TensorFlowNET.UnitTest.NativeAPI
            var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status);
            ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));

-            var shape_op = ShapeOp(ctx, hgpu);
-            TFE_OpSetDevice(shape_op, gpu_device_name, status);
-            ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));
            var retvals = new IntPtr[1];
-            int num_retvals = 1;
-            c_api.TFE_Execute(shape_op, retvals, ref num_retvals, status);
-            ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));
+            using (var shape_op = ShapeOp(ctx, hgpu))
+            {
+                TFE_OpSetDevice(shape_op, gpu_device_name, status);
+                ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));
+                int num_retvals = 1;
+                c_api.TFE_Execute(shape_op, retvals, ref num_retvals, status);
+                ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));

-            // .device of shape is GPU since the op is executed on GPU
-            device_name = TFE_TensorHandleDeviceName(retvals[0], status);
-            ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
-            ASSERT_TRUE(device_name.Contains("GPU:0"));
+                // .device of shape is GPU since the op is executed on GPU
+                device_name = TFE_TensorHandleDeviceName(retvals[0], status);
+                ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+                ASSERT_TRUE(device_name.Contains("GPU:0"));

-            // .backing_device of shape is CPU since the tensor is backed by CPU
-            backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status);
-            ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
-            ASSERT_TRUE(backing_device_name.Contains("CPU:0"));
+                // .backing_device of shape is CPU since the tensor is backed by CPU
+                backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status);
+                ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+                ASSERT_TRUE(backing_device_name.Contains("CPU:0"));
+            }

-            TFE_DeleteOp(shape_op);
            TFE_DeleteTensorHandle(retvals[0]);
            TFE_DeleteTensorHandle(hgpu);
        }


test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs (+8 -7)

@@ -28,15 +28,16 @@ namespace TensorFlowNET.UnitTest.NativeAPI
            var var_handle = CreateVariable(ctx, 12.0f, status);
            ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

-            var op = TFE_NewOp(ctx, "ReadVariableOp", status);
-            ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
-            TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
-            TFE_OpAddInput(op, var_handle, status);
-            ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
            int num_retvals = 1;
            var value_handle = new[] { IntPtr.Zero };
-            TFE_Execute(op, value_handle, ref num_retvals, status);
-            TFE_DeleteOp(op);
+            using (var op = TFE_NewOp(ctx, "ReadVariableOp", status))
+            {
+                ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+                TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+                TFE_OpAddInput(op, var_handle, status);
+                ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+                TFE_Execute(op, value_handle, ref num_retvals, status);
+            }

            ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
            ASSERT_EQ(1, num_retvals);


test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs (+29 -25)

@@ -26,7 +26,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
            return th;
        }

-        IntPtr MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b)
+        SafeOpHandle MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b)
        {
            using var status = TF_NewStatus();

@@ -64,7 +64,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
            return false;
        }

-        IntPtr ShapeOp(SafeContextHandle ctx, IntPtr a)
+        SafeOpHandle ShapeOp(SafeContextHandle ctx, IntPtr a)
        {
            using var status = TF_NewStatus();

@@ -79,39 +79,43 @@ namespace TensorFlowNET.UnitTest.NativeAPI

        unsafe IntPtr CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status)
        {
-            var op = TFE_NewOp(ctx, "VarHandleOp", status);
-            if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
-            TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
-            TFE_OpSetAttrShape(op, "shape", new long[0], 0, status);
-            TFE_OpSetAttrString(op, "container", "", 0);
-            TFE_OpSetAttrString(op, "shared_name", "", 0);
-            if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
            var var_handle = new IntPtr[1];
            int num_retvals = 1;
-            TFE_Execute(op, var_handle, ref num_retvals, status);
-            TFE_DeleteOp(op);
+            using (var op = TFE_NewOp(ctx, "VarHandleOp", status))
+            {
+                if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
+                TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+                TFE_OpSetAttrShape(op, "shape", new long[0], 0, status);
+                TFE_OpSetAttrString(op, "container", "", 0);
+                TFE_OpSetAttrString(op, "shared_name", "", 0);
+                if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
+                TFE_Execute(op, var_handle, ref num_retvals, status);
+            }
+
            if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
            CHECK_EQ(1, num_retvals);

            // Assign 'value' to it.
-            op = TFE_NewOp(ctx, "AssignVariableOp", status);
-            if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
-            TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
-            TFE_OpAddInput(op, var_handle[0], status);
+            using (var op = TFE_NewOp(ctx, "AssignVariableOp", status))
+            {
+                if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
+                TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+                TFE_OpAddInput(op, var_handle[0], status);

-            // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
-            var t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, new long[0], 0, sizeof(float));
-            tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t));
+                // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
+                var t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, new long[0], 0, sizeof(float));
+                tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t));

-            var value_handle = c_api.TFE_NewTensorHandle(t, status);
-            if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
+                var value_handle = c_api.TFE_NewTensorHandle(t, status);
+                if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;

-            TFE_OpAddInput(op, value_handle, status);
-            if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
+                TFE_OpAddInput(op, value_handle, status);
+                if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;

-            num_retvals = 0;
-            c_api.TFE_Execute(op, null, ref num_retvals, status);
-            TFE_DeleteOp(op);
+                num_retvals = 0;
+                c_api.TFE_Execute(op, null, ref num_retvals, status);
+            }
+
            if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
            CHECK_EQ(0, num_retvals);
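
Worth noting: in the old CreateVariable body, any `return IntPtr.Zero;` taken before the trailing TFE_DeleteOp call leaked the native op. The `using` statement closes that gap because it expands to a try/finally; roughly (a sketch of the compiler-generated shape, not code from this commit):

    var op = TFE_NewOp(ctx, "VarHandleOp", status);
    try
    {
        if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;   // op is still released below
        // ... set attributes and execute ...
    }
    finally
    {
        if (op != null)
            ((IDisposable)op).Dispose();   // ReleaseHandle() -> c_api.TFE_DeleteOp on every exit path
    }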



