@@ -145,7 +145,7 @@ namespace Tensorflow.Eager | |||||
return flat_result; | return flat_result; | ||||
} | } | ||||
TFE_Op GetOp(Context ctx, string op_or_function_name, Status status) | |||||
SafeOpHandle GetOp(Context ctx, string op_or_function_name, Status status) | |||||
{ | { | ||||
if (thread_local_eager_operation_map.find(ctx, out var op)) | if (thread_local_eager_operation_map.find(ctx, out var op)) | ||||
c_api.TFE_OpReset(op, op_or_function_name, ctx.device_name, status.Handle); | c_api.TFE_OpReset(op, op_or_function_name, ctx.device_name, status.Handle); | ||||
@@ -159,7 +159,7 @@ namespace Tensorflow.Eager | |||||
return op; | return op; | ||||
} | } | ||||
static UnorderedMap<Context, TFE_Op> thread_local_eager_operation_map = new UnorderedMap<Context, TFE_Op>(); | |||||
static UnorderedMap<Context, SafeOpHandle> thread_local_eager_operation_map = new UnorderedMap<Context, SafeOpHandle>(); | |||||
bool HasAccumulator() | bool HasAccumulator() | ||||
{ | { | ||||
@@ -192,7 +192,7 @@ namespace Tensorflow.Eager | |||||
ArgDef input_arg, | ArgDef input_arg, | ||||
List<object> flattened_attrs, | List<object> flattened_attrs, | ||||
List<Tensor> flattened_inputs, | List<Tensor> flattened_inputs, | ||||
IntPtr op, | |||||
SafeOpHandle op, | |||||
Status status) | Status status) | ||||
{ | { | ||||
IntPtr input_handle; | IntPtr input_handle; | ||||
@@ -224,7 +224,7 @@ namespace Tensorflow.Eager | |||||
return true; | return true; | ||||
} | } | ||||
public void SetOpAttrs(TFE_Op op, params object[] attrs) | |||||
public void SetOpAttrs(SafeOpHandle op, params object[] attrs) | |||||
{ | { | ||||
var status = tf.status; | var status = tf.status; | ||||
var len = attrs.Length; | var len = attrs.Length; | ||||
@@ -257,7 +257,7 @@ namespace Tensorflow.Eager | |||||
/// <param name="attr_value"></param> | /// <param name="attr_value"></param> | ||||
/// <param name="attr_list_sizes"></param> | /// <param name="attr_list_sizes"></param> | ||||
/// <param name="status"></param> | /// <param name="status"></param> | ||||
void SetOpAttrWithDefaults(Context ctx, IntPtr op, AttrDef attr, | |||||
void SetOpAttrWithDefaults(Context ctx, SafeOpHandle op, AttrDef attr, | |||||
string attr_name, object attr_value, | string attr_name, object attr_value, | ||||
Dictionary<string, long> attr_list_sizes, | Dictionary<string, long> attr_list_sizes, | ||||
Status status) | Status status) | ||||
@@ -290,7 +290,7 @@ namespace Tensorflow.Eager | |||||
} | } | ||||
} | } | ||||
bool SetOpAttrList(Context ctx, IntPtr op, | |||||
bool SetOpAttrList(Context ctx, SafeOpHandle op, | |||||
string key, object value, TF_AttrType type, | string key, object value, TF_AttrType type, | ||||
Dictionary<string, long> attr_list_sizes, | Dictionary<string, long> attr_list_sizes, | ||||
Status status) | Status status) | ||||
@@ -298,7 +298,7 @@ namespace Tensorflow.Eager | |||||
return false; | return false; | ||||
} | } | ||||
bool SetOpAttrScalar(Context ctx, IntPtr op, | |||||
bool SetOpAttrScalar(Context ctx, SafeOpHandle op, | |||||
string key, object value, TF_AttrType type, | string key, object value, TF_AttrType type, | ||||
Dictionary<string, long> attr_list_sizes, | Dictionary<string, long> attr_list_sizes, | ||||
Status status) | Status status) | ||||
@@ -0,0 +1,40 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using Tensorflow.Util; | |||||
namespace Tensorflow.Eager | |||||
{ | |||||
public sealed class SafeOpHandle : SafeTensorflowHandle | |||||
{ | |||||
private SafeOpHandle() | |||||
{ | |||||
} | |||||
public SafeOpHandle(IntPtr handle) | |||||
: base(handle) | |||||
{ | |||||
} | |||||
protected override bool ReleaseHandle() | |||||
{ | |||||
c_api.TFE_DeleteOp(handle); | |||||
SetHandle(IntPtr.Zero); | |||||
return true; | |||||
} | |||||
} | |||||
} |
@@ -1,23 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Eager | |||||
{ | |||||
public struct TFE_Op | |||||
{ | |||||
IntPtr _handle; | |||||
public TFE_Op(IntPtr handle) | |||||
=> _handle = handle; | |||||
public static implicit operator TFE_Op(IntPtr handle) | |||||
=> new TFE_Op(handle); | |||||
public static implicit operator IntPtr(TFE_Op tensor) | |||||
=> tensor._handle; | |||||
public override string ToString() | |||||
=> $"TFE_Op 0x{_handle.ToString("x16")}"; | |||||
} | |||||
} |
@@ -30,7 +30,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TF_AttrType TFE_OpGetAttrType(IntPtr op, string attr_name, ref byte is_list, SafeStatusHandle status); | |||||
public static extern TF_AttrType TFE_OpGetAttrType(SafeOpHandle op, string attr_name, ref byte is_list, SafeStatusHandle status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TF_AttrType TFE_OpNameGetAttrType(SafeContextHandle ctx, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); | public static extern TF_AttrType TFE_OpNameGetAttrType(SafeContextHandle ctx, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); | ||||
@@ -43,7 +43,7 @@ namespace Tensorflow | |||||
/// <param name="input_name">const char*</param> | /// <param name="input_name">const char*</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern int TFE_OpGetInputLength(IntPtr op, string input_name, SafeStatusHandle status); | |||||
public static extern int TFE_OpGetInputLength(SafeOpHandle op, string input_name, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Returns the length (number of tensors) of the output argument `output_name` | /// Returns the length (number of tensors) of the output argument `output_name` | ||||
@@ -54,7 +54,7 @@ namespace Tensorflow | |||||
/// <param name="status"></param> | /// <param name="status"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern int TFE_OpGetOutputLength(IntPtr op, string input_name, SafeStatusHandle status); | |||||
public static extern int TFE_OpGetOutputLength(SafeOpHandle op, string input_name, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -65,7 +65,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status); | |||||
public static extern int TFE_OpAddInputList(SafeOpHandle op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -98,7 +98,7 @@ namespace Tensorflow | |||||
/// <param name="num_retvals">int*</param> | /// <param name="num_retvals">int*</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status); | |||||
public static extern void TFE_Execute(SafeOpHandle op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -108,7 +108,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TFE_Op TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status); | |||||
public static extern SafeOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This | /// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This | ||||
@@ -124,7 +124,7 @@ namespace Tensorflow | |||||
/// <param name="raw_device_name">const char*</param> | /// <param name="raw_device_name">const char*</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpReset(IntPtr op_to_reset, string op_or_function_name, string raw_device_name, SafeStatusHandle status); | |||||
public static extern void TFE_OpReset(SafeOpHandle op_to_reset, string op_or_function_name, string raw_device_name, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -140,10 +140,10 @@ namespace Tensorflow | |||||
/// <param name="attr_name">const char*</param> | /// <param name="attr_name">const char*</param> | ||||
/// <param name="value">TF_DataType</param> | /// <param name="value">TF_DataType</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value); | |||||
public static extern void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrInt(IntPtr op, string attr_name, long value); | |||||
public static extern void TFE_OpSetAttrInt(SafeOpHandle op, string attr_name, long value); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -154,10 +154,10 @@ namespace Tensorflow | |||||
/// <param name="num_dims">const int</param> | /// <param name="num_dims">const int</param> | ||||
/// <param name="out_status">TF_Status*</param> | /// <param name="out_status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status); | |||||
public static extern void TFE_OpSetAttrShape(SafeOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrBool(IntPtr op, string attr_name, bool value); | |||||
public static extern void TFE_OpSetAttrBool(SafeOpHandle op, string attr_name, bool value); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -167,7 +167,7 @@ namespace Tensorflow | |||||
/// <param name="value">const void*</param> | /// <param name="value">const void*</param> | ||||
/// <param name="length">size_t</param> | /// <param name="length">size_t</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length); | |||||
public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -176,7 +176,7 @@ namespace Tensorflow | |||||
/// <param name="device_name"></param> | /// <param name="device_name"></param> | ||||
/// <param name="status"></param> | /// <param name="status"></param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetDevice(TFE_Op op, string device_name, SafeStatusHandle status); | |||||
public static extern void TFE_OpSetDevice(SafeOpHandle op, string device_name, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -185,7 +185,7 @@ namespace Tensorflow | |||||
/// <param name="h">TFE_TensorHandle*</param> | /// <param name="h">TFE_TensorHandle*</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpAddInput(IntPtr op, IntPtr h, SafeStatusHandle status); | |||||
public static extern void TFE_OpAddInput(SafeOpHandle op, IntPtr h, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -82,25 +82,25 @@ namespace TensorFlowNET.UnitTest | |||||
protected ulong TF_TensorByteSize(IntPtr t) | protected ulong TF_TensorByteSize(IntPtr t) | ||||
=> c_api.TF_TensorByteSize(t); | => c_api.TF_TensorByteSize(t); | ||||
protected void TFE_OpAddInput(IntPtr op, IntPtr h, SafeStatusHandle status) | |||||
protected void TFE_OpAddInput(SafeOpHandle op, IntPtr h, SafeStatusHandle status) | |||||
=> c_api.TFE_OpAddInput(op, h, status); | => c_api.TFE_OpAddInput(op, h, status); | ||||
protected void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value) | |||||
protected void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value) | |||||
=> c_api.TFE_OpSetAttrType(op, attr_name, value); | => c_api.TFE_OpSetAttrType(op, attr_name, value); | ||||
protected void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status) | |||||
protected void TFE_OpSetAttrShape(SafeOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status) | |||||
=> c_api.TFE_OpSetAttrShape(op, attr_name, dims, num_dims, out_status); | => c_api.TFE_OpSetAttrShape(op, attr_name, dims, num_dims, out_status); | ||||
protected void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length) | |||||
protected void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length) | |||||
=> c_api.TFE_OpSetAttrString(op, attr_name, value, length); | => c_api.TFE_OpSetAttrString(op, attr_name, value, length); | ||||
protected IntPtr TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) | |||||
protected SafeOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) | |||||
=> c_api.TFE_NewOp(ctx, op_or_function_name, status); | => c_api.TFE_NewOp(ctx, op_or_function_name, status); | ||||
protected IntPtr TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status) | protected IntPtr TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status) | ||||
=> c_api.TFE_NewTensorHandle(t, status); | => c_api.TFE_NewTensorHandle(t, status); | ||||
protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status) | |||||
protected void TFE_Execute(SafeOpHandle op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status) | |||||
=> c_api.TFE_Execute(op, retvals, ref num_retvals, status); | => c_api.TFE_Execute(op, retvals, ref num_retvals, status); | ||||
protected SafeContextOptionsHandle TFE_NewContextOptions() | protected SafeContextOptionsHandle TFE_NewContextOptions() | ||||
@@ -109,21 +109,18 @@ namespace TensorFlowNET.UnitTest | |||||
protected SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status) | protected SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status) | ||||
=> c_api.TFE_NewContext(opts, status); | => c_api.TFE_NewContext(opts, status); | ||||
protected int TFE_OpGetInputLength(IntPtr op, string input_name, SafeStatusHandle status) | |||||
protected int TFE_OpGetInputLength(SafeOpHandle op, string input_name, SafeStatusHandle status) | |||||
=> c_api.TFE_OpGetInputLength(op, input_name, status); | => c_api.TFE_OpGetInputLength(op, input_name, status); | ||||
protected int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status) | |||||
protected int TFE_OpAddInputList(SafeOpHandle op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status) | |||||
=> c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); | => c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); | ||||
protected int TFE_OpGetOutputLength(IntPtr op, string input_name, SafeStatusHandle status) | |||||
protected int TFE_OpGetOutputLength(SafeOpHandle op, string input_name, SafeStatusHandle status) | |||||
=> c_api.TFE_OpGetOutputLength(op, input_name, status); | => c_api.TFE_OpGetOutputLength(op, input_name, status); | ||||
protected void TFE_DeleteTensorHandle(IntPtr h) | protected void TFE_DeleteTensorHandle(IntPtr h) | ||||
=> c_api.TFE_DeleteTensorHandle(h); | => c_api.TFE_DeleteTensorHandle(h); | ||||
protected void TFE_DeleteOp(IntPtr op) | |||||
=> c_api.TFE_DeleteOp(op); | |||||
protected SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx) | protected SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx) | ||||
=> c_api.TFE_ContextGetExecutorForThread(ctx); | => c_api.TFE_ContextGetExecutorForThread(ctx); | ||||
@@ -154,7 +151,7 @@ namespace TensorFlowNET.UnitTest | |||||
protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | ||||
=> c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | ||||
protected void TFE_OpSetDevice(IntPtr op, string device_name, SafeStatusHandle status) | |||||
protected void TFE_OpSetDevice(SafeOpHandle op, string device_name, SafeStatusHandle status) | |||||
=> c_api.TFE_OpSetDevice(op, device_name, status); | => c_api.TFE_OpSetDevice(op, device_name, status); | ||||
} | } | ||||
} | } |
@@ -34,13 +34,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
var m = TestMatrixTensorHandle(); | var m = TestMatrixTensorHandle(); | ||||
var matmul = MatMulOp(ctx, m, m); | |||||
var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero }; | var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero }; | ||||
int num_retvals = 2; | |||||
c_api.TFE_Execute(matmul, retvals, ref num_retvals, status); | |||||
EXPECT_EQ(1, num_retvals); | |||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TFE_DeleteOp(matmul); | |||||
using (var matmul = MatMulOp(ctx, m, m)) | |||||
{ | |||||
int num_retvals = 2; | |||||
c_api.TFE_Execute(matmul, retvals, ref num_retvals, status); | |||||
EXPECT_EQ(1, num_retvals); | |||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
} | |||||
TFE_DeleteTensorHandle(m); | TFE_DeleteTensorHandle(m); | ||||
t = TFE_TensorHandleResolve(retvals[0], status); | t = TFE_TensorHandleResolve(retvals[0], status); | ||||
@@ -26,37 +26,38 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
var input1 = TestMatrixTensorHandle(); | var input1 = TestMatrixTensorHandle(); | ||||
var input2 = TestMatrixTensorHandle(); | var input2 = TestMatrixTensorHandle(); | ||||
var identityOp = TFE_NewOp(ctx, "IdentityN", status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
var retvals = new IntPtr[2]; | |||||
using (var identityOp = TFE_NewOp(ctx, "IdentityN", status)) | |||||
{ | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
// Try to retrieve lengths before building the attributes (should fail) | |||||
EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status)); | |||||
CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); | |||||
CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
// Try to retrieve lengths before building the attributes (should fail) | |||||
EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status)); | |||||
CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); | |||||
CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
var inputs = new IntPtr[] { input1, input2 }; | |||||
TFE_OpAddInputList(identityOp, inputs, 2, status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
var inputs = new IntPtr[] { input1, input2 }; | |||||
TFE_OpAddInputList(identityOp, inputs, 2, status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
// Try to retrieve lengths before executing the op (should work) | |||||
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
// Try to retrieve lengths before executing the op (should work) | |||||
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
var retvals = new IntPtr[2]; | |||||
int num_retvals = 2; | |||||
TFE_Execute(identityOp, retvals, ref num_retvals, status); | |||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
int num_retvals = 2; | |||||
TFE_Execute(identityOp, retvals, ref num_retvals, status); | |||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
// Try to retrieve lengths after executing the op (should work) | |||||
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
// Try to retrieve lengths after executing the op (should work) | |||||
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
} | |||||
TFE_DeleteOp(identityOp); | |||||
TFE_DeleteTensorHandle(input1); | TFE_DeleteTensorHandle(input1); | ||||
TFE_DeleteTensorHandle(input2); | TFE_DeleteTensorHandle(input2); | ||||
TFE_DeleteTensorHandle(retvals[0]); | TFE_DeleteTensorHandle(retvals[0]); | ||||
@@ -27,27 +27,28 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
var condition = TestScalarTensorHandle(true); | var condition = TestScalarTensorHandle(true); | ||||
var t1 = TestMatrixTensorHandle(); | var t1 = TestMatrixTensorHandle(); | ||||
var t2 = TestAxisTensorHandle(); | var t2 = TestAxisTensorHandle(); | ||||
var assertOp = TFE_NewOp(ctx, "Assert", status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TFE_OpAddInput(assertOp, condition, status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
var data = new[] { condition, t1, t2 }; | |||||
TFE_OpAddInputList(assertOp, data, 3, status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
var retvals = new IntPtr[1]; | |||||
using (var assertOp = TFE_NewOp(ctx, "Assert", status)) | |||||
{ | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TFE_OpAddInput(assertOp, condition, status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
var data = new[] { condition, t1, t2 }; | |||||
TFE_OpAddInputList(assertOp, data, 3, status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
/*var attr_values = Graph.TFE_GetOpDef("Assert").Attr; | |||||
var attr_found = attr_values.First(x => x.Name == "T"); | |||||
EXPECT_NE(attr_found, attr_values.Last());*/ | |||||
// EXPECT_EQ(attr_found.Type[0], "DT_BOOL"); | |||||
//EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); | |||||
//EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); | |||||
/*var attr_values = Graph.TFE_GetOpDef("Assert").Attr; | |||||
var attr_found = attr_values.First(x => x.Name == "T"); | |||||
EXPECT_NE(attr_found, attr_values.Last());*/ | |||||
// EXPECT_EQ(attr_found.Type[0], "DT_BOOL"); | |||||
//EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); | |||||
//EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); | |||||
var retvals = new IntPtr[1]; | |||||
int num_retvals = 1; | |||||
TFE_Execute(assertOp, retvals, ref num_retvals, status); | |||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
int num_retvals = 1; | |||||
TFE_Execute(assertOp, retvals, ref num_retvals, status); | |||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
} | |||||
TFE_DeleteOp(assertOp); | |||||
TFE_DeleteTensorHandle(condition); | TFE_DeleteTensorHandle(condition); | ||||
TFE_DeleteTensorHandle(t1); | TFE_DeleteTensorHandle(t1); | ||||
TFE_DeleteTensorHandle(t2); | TFE_DeleteTensorHandle(t2); | ||||
@@ -40,25 +40,26 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status); | var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status); | ||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | ||||
var shape_op = ShapeOp(ctx, hgpu); | |||||
TFE_OpSetDevice(shape_op, gpu_device_name, status); | |||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||||
var retvals = new IntPtr[1]; | var retvals = new IntPtr[1]; | ||||
int num_retvals = 1; | |||||
c_api.TFE_Execute(shape_op, retvals, ref num_retvals, status); | |||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||||
using (var shape_op = ShapeOp(ctx, hgpu)) | |||||
{ | |||||
TFE_OpSetDevice(shape_op, gpu_device_name, status); | |||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||||
int num_retvals = 1; | |||||
c_api.TFE_Execute(shape_op, retvals, ref num_retvals, status); | |||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||||
// .device of shape is GPU since the op is executed on GPU | |||||
device_name = TFE_TensorHandleDeviceName(retvals[0], status); | |||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
ASSERT_TRUE(device_name.Contains("GPU:0")); | |||||
// .device of shape is GPU since the op is executed on GPU | |||||
device_name = TFE_TensorHandleDeviceName(retvals[0], status); | |||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
ASSERT_TRUE(device_name.Contains("GPU:0")); | |||||
// .backing_device of shape is CPU since the tensor is backed by CPU | |||||
backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status); | |||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
ASSERT_TRUE(backing_device_name.Contains("CPU:0")); | |||||
// .backing_device of shape is CPU since the tensor is backed by CPU | |||||
backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status); | |||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
ASSERT_TRUE(backing_device_name.Contains("CPU:0")); | |||||
} | |||||
TFE_DeleteOp(shape_op); | |||||
TFE_DeleteTensorHandle(retvals[0]); | TFE_DeleteTensorHandle(retvals[0]); | ||||
TFE_DeleteTensorHandle(hgpu); | TFE_DeleteTensorHandle(hgpu); | ||||
} | } | ||||
@@ -28,15 +28,16 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
var var_handle = CreateVariable(ctx, 12.0f, status); | var var_handle = CreateVariable(ctx, 12.0f, status); | ||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
var op = TFE_NewOp(ctx, "ReadVariableOp", status); | |||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||||
TFE_OpAddInput(op, var_handle, status); | |||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
int num_retvals = 1; | int num_retvals = 1; | ||||
var value_handle = new[] { IntPtr.Zero }; | var value_handle = new[] { IntPtr.Zero }; | ||||
TFE_Execute(op, value_handle, ref num_retvals, status); | |||||
TFE_DeleteOp(op); | |||||
using (var op = TFE_NewOp(ctx, "ReadVariableOp", status)) | |||||
{ | |||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||||
TFE_OpAddInput(op, var_handle, status); | |||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TFE_Execute(op, value_handle, ref num_retvals, status); | |||||
} | |||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
ASSERT_EQ(1, num_retvals); | ASSERT_EQ(1, num_retvals); | ||||
@@ -26,7 +26,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
return th; | return th; | ||||
} | } | ||||
IntPtr MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b) | |||||
SafeOpHandle MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b) | |||||
{ | { | ||||
using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
@@ -64,7 +64,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
return false; | return false; | ||||
} | } | ||||
IntPtr ShapeOp(SafeContextHandle ctx, IntPtr a) | |||||
SafeOpHandle ShapeOp(SafeContextHandle ctx, IntPtr a) | |||||
{ | { | ||||
using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
@@ -79,39 +79,43 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
unsafe IntPtr CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status) | unsafe IntPtr CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status) | ||||
{ | { | ||||
var op = TFE_NewOp(ctx, "VarHandleOp", status); | |||||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||||
TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); | |||||
TFE_OpSetAttrString(op, "container", "", 0); | |||||
TFE_OpSetAttrString(op, "shared_name", "", 0); | |||||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
var var_handle = new IntPtr[1]; | var var_handle = new IntPtr[1]; | ||||
int num_retvals = 1; | int num_retvals = 1; | ||||
TFE_Execute(op, var_handle, ref num_retvals, status); | |||||
TFE_DeleteOp(op); | |||||
using (var op = TFE_NewOp(ctx, "VarHandleOp", status)) | |||||
{ | |||||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||||
TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); | |||||
TFE_OpSetAttrString(op, "container", "", 0); | |||||
TFE_OpSetAttrString(op, "shared_name", "", 0); | |||||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
TFE_Execute(op, var_handle, ref num_retvals, status); | |||||
} | |||||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | ||||
CHECK_EQ(1, num_retvals); | CHECK_EQ(1, num_retvals); | ||||
// Assign 'value' to it. | // Assign 'value' to it. | ||||
op = TFE_NewOp(ctx, "AssignVariableOp", status); | |||||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||||
TFE_OpAddInput(op, var_handle[0], status); | |||||
using (var op = TFE_NewOp(ctx, "AssignVariableOp", status)) | |||||
{ | |||||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||||
TFE_OpAddInput(op, var_handle[0], status); | |||||
// Convert 'value' to a TF_Tensor then a TFE_TensorHandle. | |||||
var t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, new long[0], 0, sizeof(float)); | |||||
tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t)); | |||||
// Convert 'value' to a TF_Tensor then a TFE_TensorHandle. | |||||
var t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, new long[0], 0, sizeof(float)); | |||||
tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t)); | |||||
var value_handle = c_api.TFE_NewTensorHandle(t, status); | |||||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
var value_handle = c_api.TFE_NewTensorHandle(t, status); | |||||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
TFE_OpAddInput(op, value_handle, status); | |||||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
TFE_OpAddInput(op, value_handle, status); | |||||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
num_retvals = 0; | |||||
c_api.TFE_Execute(op, null, ref num_retvals, status); | |||||
} | |||||
num_retvals = 0; | |||||
c_api.TFE_Execute(op, null, ref num_retvals, status); | |||||
TFE_DeleteOp(op); | |||||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | ||||
CHECK_EQ(0, num_retvals); | CHECK_EQ(0, num_retvals); | ||||