@@ -145,7 +145,7 @@ namespace Tensorflow.Eager | |||
return flat_result; | |||
} | |||
TFE_Op GetOp(Context ctx, string op_or_function_name, Status status) | |||
SafeOpHandle GetOp(Context ctx, string op_or_function_name, Status status) | |||
{ | |||
if (thread_local_eager_operation_map.find(ctx, out var op)) | |||
c_api.TFE_OpReset(op, op_or_function_name, ctx.device_name, status.Handle); | |||
@@ -159,7 +159,7 @@ namespace Tensorflow.Eager | |||
return op; | |||
} | |||
static UnorderedMap<Context, TFE_Op> thread_local_eager_operation_map = new UnorderedMap<Context, TFE_Op>(); | |||
static UnorderedMap<Context, SafeOpHandle> thread_local_eager_operation_map = new UnorderedMap<Context, SafeOpHandle>(); | |||
bool HasAccumulator() | |||
{ | |||
@@ -192,7 +192,7 @@ namespace Tensorflow.Eager | |||
ArgDef input_arg, | |||
List<object> flattened_attrs, | |||
List<Tensor> flattened_inputs, | |||
IntPtr op, | |||
SafeOpHandle op, | |||
Status status) | |||
{ | |||
IntPtr input_handle; | |||
@@ -224,7 +224,7 @@ namespace Tensorflow.Eager | |||
return true; | |||
} | |||
public void SetOpAttrs(TFE_Op op, params object[] attrs) | |||
public void SetOpAttrs(SafeOpHandle op, params object[] attrs) | |||
{ | |||
var status = tf.status; | |||
var len = attrs.Length; | |||
@@ -257,7 +257,7 @@ namespace Tensorflow.Eager | |||
/// <param name="attr_value"></param> | |||
/// <param name="attr_list_sizes"></param> | |||
/// <param name="status"></param> | |||
void SetOpAttrWithDefaults(Context ctx, IntPtr op, AttrDef attr, | |||
void SetOpAttrWithDefaults(Context ctx, SafeOpHandle op, AttrDef attr, | |||
string attr_name, object attr_value, | |||
Dictionary<string, long> attr_list_sizes, | |||
Status status) | |||
@@ -290,7 +290,7 @@ namespace Tensorflow.Eager | |||
} | |||
} | |||
bool SetOpAttrList(Context ctx, IntPtr op, | |||
bool SetOpAttrList(Context ctx, SafeOpHandle op, | |||
string key, object value, TF_AttrType type, | |||
Dictionary<string, long> attr_list_sizes, | |||
Status status) | |||
@@ -298,7 +298,7 @@ namespace Tensorflow.Eager | |||
return false; | |||
} | |||
bool SetOpAttrScalar(Context ctx, IntPtr op, | |||
bool SetOpAttrScalar(Context ctx, SafeOpHandle op, | |||
string key, object value, TF_AttrType type, | |||
Dictionary<string, long> attr_list_sizes, | |||
Status status) | |||
@@ -0,0 +1,40 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using Tensorflow.Util; | |||
namespace Tensorflow.Eager | |||
{ | |||
public sealed class SafeOpHandle : SafeTensorflowHandle | |||
{ | |||
private SafeOpHandle() | |||
{ | |||
} | |||
public SafeOpHandle(IntPtr handle) | |||
: base(handle) | |||
{ | |||
} | |||
protected override bool ReleaseHandle() | |||
{ | |||
c_api.TFE_DeleteOp(handle); | |||
SetHandle(IntPtr.Zero); | |||
return true; | |||
} | |||
} | |||
} |
@@ -1,23 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Eager | |||
{ | |||
public struct TFE_Op | |||
{ | |||
IntPtr _handle; | |||
public TFE_Op(IntPtr handle) | |||
=> _handle = handle; | |||
public static implicit operator TFE_Op(IntPtr handle) | |||
=> new TFE_Op(handle); | |||
public static implicit operator IntPtr(TFE_Op tensor) | |||
=> tensor._handle; | |||
public override string ToString() | |||
=> $"TFE_Op 0x{_handle.ToString("x16")}"; | |||
} | |||
} |
@@ -30,7 +30,7 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern TF_AttrType TFE_OpGetAttrType(IntPtr op, string attr_name, ref byte is_list, SafeStatusHandle status); | |||
public static extern TF_AttrType TFE_OpGetAttrType(SafeOpHandle op, string attr_name, ref byte is_list, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern TF_AttrType TFE_OpNameGetAttrType(SafeContextHandle ctx, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); | |||
@@ -43,7 +43,7 @@ namespace Tensorflow | |||
/// <param name="input_name">const char*</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TFE_OpGetInputLength(IntPtr op, string input_name, SafeStatusHandle status); | |||
public static extern int TFE_OpGetInputLength(SafeOpHandle op, string input_name, SafeStatusHandle status); | |||
/// <summary> | |||
/// Returns the length (number of tensors) of the output argument `output_name` | |||
@@ -54,7 +54,7 @@ namespace Tensorflow | |||
/// <param name="status"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TFE_OpGetOutputLength(IntPtr op, string input_name, SafeStatusHandle status); | |||
public static extern int TFE_OpGetOutputLength(SafeOpHandle op, string input_name, SafeStatusHandle status); | |||
/// <summary> | |||
/// | |||
@@ -65,7 +65,7 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status); | |||
public static extern int TFE_OpAddInputList(SafeOpHandle op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status); | |||
/// <summary> | |||
/// | |||
@@ -98,7 +98,7 @@ namespace Tensorflow | |||
/// <param name="num_retvals">int*</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status); | |||
public static extern void TFE_Execute(SafeOpHandle op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status); | |||
/// <summary> | |||
/// | |||
@@ -108,7 +108,7 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern TFE_Op TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status); | |||
public static extern SafeOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status); | |||
/// <summary> | |||
/// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This | |||
@@ -124,7 +124,7 @@ namespace Tensorflow | |||
/// <param name="raw_device_name">const char*</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpReset(IntPtr op_to_reset, string op_or_function_name, string raw_device_name, SafeStatusHandle status); | |||
public static extern void TFE_OpReset(SafeOpHandle op_to_reset, string op_or_function_name, string raw_device_name, SafeStatusHandle status); | |||
/// <summary> | |||
/// | |||
@@ -140,10 +140,10 @@ namespace Tensorflow | |||
/// <param name="attr_name">const char*</param> | |||
/// <param name="value">TF_DataType</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value); | |||
public static extern void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrInt(IntPtr op, string attr_name, long value); | |||
public static extern void TFE_OpSetAttrInt(SafeOpHandle op, string attr_name, long value); | |||
/// <summary> | |||
/// | |||
@@ -154,10 +154,10 @@ namespace Tensorflow | |||
/// <param name="num_dims">const int</param> | |||
/// <param name="out_status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status); | |||
public static extern void TFE_OpSetAttrShape(SafeOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrBool(IntPtr op, string attr_name, bool value); | |||
public static extern void TFE_OpSetAttrBool(SafeOpHandle op, string attr_name, bool value); | |||
/// <summary> | |||
/// | |||
@@ -167,7 +167,7 @@ namespace Tensorflow | |||
/// <param name="value">const void*</param> | |||
/// <param name="length">size_t</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length); | |||
public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length); | |||
/// <summary> | |||
/// | |||
@@ -176,7 +176,7 @@ namespace Tensorflow | |||
/// <param name="device_name"></param> | |||
/// <param name="status"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetDevice(TFE_Op op, string device_name, SafeStatusHandle status); | |||
public static extern void TFE_OpSetDevice(SafeOpHandle op, string device_name, SafeStatusHandle status); | |||
/// <summary> | |||
/// | |||
@@ -185,7 +185,7 @@ namespace Tensorflow | |||
/// <param name="h">TFE_TensorHandle*</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpAddInput(IntPtr op, IntPtr h, SafeStatusHandle status); | |||
public static extern void TFE_OpAddInput(SafeOpHandle op, IntPtr h, SafeStatusHandle status); | |||
/// <summary> | |||
/// | |||
@@ -82,25 +82,25 @@ namespace TensorFlowNET.UnitTest | |||
protected ulong TF_TensorByteSize(IntPtr t) | |||
=> c_api.TF_TensorByteSize(t); | |||
protected void TFE_OpAddInput(IntPtr op, IntPtr h, SafeStatusHandle status) | |||
protected void TFE_OpAddInput(SafeOpHandle op, IntPtr h, SafeStatusHandle status) | |||
=> c_api.TFE_OpAddInput(op, h, status); | |||
protected void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value) | |||
protected void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value) | |||
=> c_api.TFE_OpSetAttrType(op, attr_name, value); | |||
protected void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status) | |||
protected void TFE_OpSetAttrShape(SafeOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status) | |||
=> c_api.TFE_OpSetAttrShape(op, attr_name, dims, num_dims, out_status); | |||
protected void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length) | |||
protected void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length) | |||
=> c_api.TFE_OpSetAttrString(op, attr_name, value, length); | |||
protected IntPtr TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) | |||
protected SafeOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) | |||
=> c_api.TFE_NewOp(ctx, op_or_function_name, status); | |||
protected IntPtr TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status) | |||
=> c_api.TFE_NewTensorHandle(t, status); | |||
protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status) | |||
protected void TFE_Execute(SafeOpHandle op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status) | |||
=> c_api.TFE_Execute(op, retvals, ref num_retvals, status); | |||
protected SafeContextOptionsHandle TFE_NewContextOptions() | |||
@@ -109,21 +109,18 @@ namespace TensorFlowNET.UnitTest | |||
protected SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status) | |||
=> c_api.TFE_NewContext(opts, status); | |||
protected int TFE_OpGetInputLength(IntPtr op, string input_name, SafeStatusHandle status) | |||
protected int TFE_OpGetInputLength(SafeOpHandle op, string input_name, SafeStatusHandle status) | |||
=> c_api.TFE_OpGetInputLength(op, input_name, status); | |||
protected int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status) | |||
protected int TFE_OpAddInputList(SafeOpHandle op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status) | |||
=> c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); | |||
protected int TFE_OpGetOutputLength(IntPtr op, string input_name, SafeStatusHandle status) | |||
protected int TFE_OpGetOutputLength(SafeOpHandle op, string input_name, SafeStatusHandle status) | |||
=> c_api.TFE_OpGetOutputLength(op, input_name, status); | |||
protected void TFE_DeleteTensorHandle(IntPtr h) | |||
=> c_api.TFE_DeleteTensorHandle(h); | |||
protected void TFE_DeleteOp(IntPtr op) | |||
=> c_api.TFE_DeleteOp(op); | |||
protected SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx) | |||
=> c_api.TFE_ContextGetExecutorForThread(ctx); | |||
@@ -154,7 +151,7 @@ namespace TensorFlowNET.UnitTest | |||
protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | |||
=> c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | |||
protected void TFE_OpSetDevice(IntPtr op, string device_name, SafeStatusHandle status) | |||
protected void TFE_OpSetDevice(SafeOpHandle op, string device_name, SafeStatusHandle status) | |||
=> c_api.TFE_OpSetDevice(op, device_name, status); | |||
} | |||
} |
@@ -34,13 +34,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
var m = TestMatrixTensorHandle(); | |||
var matmul = MatMulOp(ctx, m, m); | |||
var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero }; | |||
int num_retvals = 2; | |||
c_api.TFE_Execute(matmul, retvals, ref num_retvals, status); | |||
EXPECT_EQ(1, num_retvals); | |||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_DeleteOp(matmul); | |||
using (var matmul = MatMulOp(ctx, m, m)) | |||
{ | |||
int num_retvals = 2; | |||
c_api.TFE_Execute(matmul, retvals, ref num_retvals, status); | |||
EXPECT_EQ(1, num_retvals); | |||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
} | |||
TFE_DeleteTensorHandle(m); | |||
t = TFE_TensorHandleResolve(retvals[0], status); | |||
@@ -26,37 +26,38 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
var input1 = TestMatrixTensorHandle(); | |||
var input2 = TestMatrixTensorHandle(); | |||
var identityOp = TFE_NewOp(ctx, "IdentityN", status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
var retvals = new IntPtr[2]; | |||
using (var identityOp = TFE_NewOp(ctx, "IdentityN", status)) | |||
{ | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
// Try to retrieve lengths before building the attributes (should fail) | |||
EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status)); | |||
CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); | |||
CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
// Try to retrieve lengths before building the attributes (should fail) | |||
EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status)); | |||
CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); | |||
CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
var inputs = new IntPtr[] { input1, input2 }; | |||
TFE_OpAddInputList(identityOp, inputs, 2, status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
var inputs = new IntPtr[] { input1, input2 }; | |||
TFE_OpAddInputList(identityOp, inputs, 2, status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
// Try to retrieve lengths before executing the op (should work) | |||
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
// Try to retrieve lengths before executing the op (should work) | |||
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
var retvals = new IntPtr[2]; | |||
int num_retvals = 2; | |||
TFE_Execute(identityOp, retvals, ref num_retvals, status); | |||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
int num_retvals = 2; | |||
TFE_Execute(identityOp, retvals, ref num_retvals, status); | |||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
// Try to retrieve lengths after executing the op (should work) | |||
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
// Try to retrieve lengths after executing the op (should work) | |||
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
} | |||
TFE_DeleteOp(identityOp); | |||
TFE_DeleteTensorHandle(input1); | |||
TFE_DeleteTensorHandle(input2); | |||
TFE_DeleteTensorHandle(retvals[0]); | |||
@@ -27,27 +27,28 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
var condition = TestScalarTensorHandle(true); | |||
var t1 = TestMatrixTensorHandle(); | |||
var t2 = TestAxisTensorHandle(); | |||
var assertOp = TFE_NewOp(ctx, "Assert", status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_OpAddInput(assertOp, condition, status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
var data = new[] { condition, t1, t2 }; | |||
TFE_OpAddInputList(assertOp, data, 3, status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
var retvals = new IntPtr[1]; | |||
using (var assertOp = TFE_NewOp(ctx, "Assert", status)) | |||
{ | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_OpAddInput(assertOp, condition, status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
var data = new[] { condition, t1, t2 }; | |||
TFE_OpAddInputList(assertOp, data, 3, status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
/*var attr_values = Graph.TFE_GetOpDef("Assert").Attr; | |||
var attr_found = attr_values.First(x => x.Name == "T"); | |||
EXPECT_NE(attr_found, attr_values.Last());*/ | |||
// EXPECT_EQ(attr_found.Type[0], "DT_BOOL"); | |||
//EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); | |||
//EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); | |||
/*var attr_values = Graph.TFE_GetOpDef("Assert").Attr; | |||
var attr_found = attr_values.First(x => x.Name == "T"); | |||
EXPECT_NE(attr_found, attr_values.Last());*/ | |||
// EXPECT_EQ(attr_found.Type[0], "DT_BOOL"); | |||
//EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); | |||
//EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); | |||
var retvals = new IntPtr[1]; | |||
int num_retvals = 1; | |||
TFE_Execute(assertOp, retvals, ref num_retvals, status); | |||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
int num_retvals = 1; | |||
TFE_Execute(assertOp, retvals, ref num_retvals, status); | |||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
} | |||
TFE_DeleteOp(assertOp); | |||
TFE_DeleteTensorHandle(condition); | |||
TFE_DeleteTensorHandle(t1); | |||
TFE_DeleteTensorHandle(t2); | |||
@@ -40,25 +40,26 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status); | |||
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||
var shape_op = ShapeOp(ctx, hgpu); | |||
TFE_OpSetDevice(shape_op, gpu_device_name, status); | |||
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||
var retvals = new IntPtr[1]; | |||
int num_retvals = 1; | |||
c_api.TFE_Execute(shape_op, retvals, ref num_retvals, status); | |||
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||
using (var shape_op = ShapeOp(ctx, hgpu)) | |||
{ | |||
TFE_OpSetDevice(shape_op, gpu_device_name, status); | |||
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||
int num_retvals = 1; | |||
c_api.TFE_Execute(shape_op, retvals, ref num_retvals, status); | |||
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||
// .device of shape is GPU since the op is executed on GPU | |||
device_name = TFE_TensorHandleDeviceName(retvals[0], status); | |||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
ASSERT_TRUE(device_name.Contains("GPU:0")); | |||
// .device of shape is GPU since the op is executed on GPU | |||
device_name = TFE_TensorHandleDeviceName(retvals[0], status); | |||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
ASSERT_TRUE(device_name.Contains("GPU:0")); | |||
// .backing_device of shape is CPU since the tensor is backed by CPU | |||
backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status); | |||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
ASSERT_TRUE(backing_device_name.Contains("CPU:0")); | |||
// .backing_device of shape is CPU since the tensor is backed by CPU | |||
backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status); | |||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
ASSERT_TRUE(backing_device_name.Contains("CPU:0")); | |||
} | |||
TFE_DeleteOp(shape_op); | |||
TFE_DeleteTensorHandle(retvals[0]); | |||
TFE_DeleteTensorHandle(hgpu); | |||
} | |||
@@ -28,15 +28,16 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
var var_handle = CreateVariable(ctx, 12.0f, status); | |||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
var op = TFE_NewOp(ctx, "ReadVariableOp", status); | |||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||
TFE_OpAddInput(op, var_handle, status); | |||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
int num_retvals = 1; | |||
var value_handle = new[] { IntPtr.Zero }; | |||
TFE_Execute(op, value_handle, ref num_retvals, status); | |||
TFE_DeleteOp(op); | |||
using (var op = TFE_NewOp(ctx, "ReadVariableOp", status)) | |||
{ | |||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||
TFE_OpAddInput(op, var_handle, status); | |||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_Execute(op, value_handle, ref num_retvals, status); | |||
} | |||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
ASSERT_EQ(1, num_retvals); | |||
@@ -26,7 +26,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
return th; | |||
} | |||
IntPtr MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b) | |||
SafeOpHandle MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b) | |||
{ | |||
using var status = TF_NewStatus(); | |||
@@ -64,7 +64,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
return false; | |||
} | |||
IntPtr ShapeOp(SafeContextHandle ctx, IntPtr a) | |||
SafeOpHandle ShapeOp(SafeContextHandle ctx, IntPtr a) | |||
{ | |||
using var status = TF_NewStatus(); | |||
@@ -79,39 +79,43 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
unsafe IntPtr CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status) | |||
{ | |||
var op = TFE_NewOp(ctx, "VarHandleOp", status); | |||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||
TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); | |||
TFE_OpSetAttrString(op, "container", "", 0); | |||
TFE_OpSetAttrString(op, "shared_name", "", 0); | |||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
var var_handle = new IntPtr[1]; | |||
int num_retvals = 1; | |||
TFE_Execute(op, var_handle, ref num_retvals, status); | |||
TFE_DeleteOp(op); | |||
using (var op = TFE_NewOp(ctx, "VarHandleOp", status)) | |||
{ | |||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||
TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); | |||
TFE_OpSetAttrString(op, "container", "", 0); | |||
TFE_OpSetAttrString(op, "shared_name", "", 0); | |||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
TFE_Execute(op, var_handle, ref num_retvals, status); | |||
} | |||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
CHECK_EQ(1, num_retvals); | |||
// Assign 'value' to it. | |||
op = TFE_NewOp(ctx, "AssignVariableOp", status); | |||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||
TFE_OpAddInput(op, var_handle[0], status); | |||
using (var op = TFE_NewOp(ctx, "AssignVariableOp", status)) | |||
{ | |||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||
TFE_OpAddInput(op, var_handle[0], status); | |||
// Convert 'value' to a TF_Tensor then a TFE_TensorHandle. | |||
var t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, new long[0], 0, sizeof(float)); | |||
tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t)); | |||
// Convert 'value' to a TF_Tensor then a TFE_TensorHandle. | |||
var t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, new long[0], 0, sizeof(float)); | |||
tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t)); | |||
var value_handle = c_api.TFE_NewTensorHandle(t, status); | |||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
var value_handle = c_api.TFE_NewTensorHandle(t, status); | |||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
TFE_OpAddInput(op, value_handle, status); | |||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
TFE_OpAddInput(op, value_handle, status); | |||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
num_retvals = 0; | |||
c_api.TFE_Execute(op, null, ref num_retvals, status); | |||
} | |||
num_retvals = 0; | |||
c_api.TFE_Execute(op, null, ref num_retvals, status); | |||
TFE_DeleteOp(op); | |||
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
CHECK_EQ(0, num_retvals); | |||