@@ -91,6 +91,18 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_DeleteContext(IntPtr ctx); | public static extern void TFE_DeleteContext(IntPtr ctx); | ||||
/// <summary> | |||||
/// Execute the operation defined by <paramref name="op"/> and return handles to computed | |||||
/// tensors in <paramref name="retvals"/>. | |||||
/// </summary> | |||||
/// <remarks> | |||||
/// Upon successful return, the first <paramref name="num_retvals"/> slots in <paramref name="retvals"/> will | |||||
/// contain handle instances which the caller is responsible for disposing once they are no longer in use. | |||||
/// </remarks> | |||||
/// <param name="op"></param> | |||||
/// <param name="retvals"></param> | |||||
/// <param name="num_retvals"></param> | |||||
/// <param name="status"></param> | |||||
public static void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) | public static void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) | ||||
{ | { | ||||
unsafe | unsafe | ||||
@@ -100,6 +112,9 @@ namespace Tensorflow | |||||
TFE_Execute(op, rawReturns, ref num_retvals, status); | TFE_Execute(op, rawReturns, ref num_retvals, status); | ||||
for (var i = 0; i < num_retvals; i++) | for (var i = 0; i < num_retvals; i++) | ||||
{ | { | ||||
// A handle is created for every return, even if rawReturns[i] is null. The resulting handle will be | |||||
// non-null but invalid, which is the same behavior P/Invoke gives for non-array SafeHandle return | |||||
// values. | |||||
retvals[i] = new SafeTensorHandleHandle(rawReturns[i]); | retvals[i] = new SafeTensorHandleHandle(rawReturns[i]); | ||||
} | } | ||||
} | } | ||||
@@ -34,23 +34,23 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
var retvals = new SafeTensorHandleHandle[2]; | var retvals = new SafeTensorHandleHandle[2]; | ||||
try | |||||
using (var m = TestMatrixTensorHandle()) | |||||
using (var matmul = MatMulOp(ctx, m, m)) | |||||
{ | { | ||||
using (var m = TestMatrixTensorHandle()) | |||||
using (var matmul = MatMulOp(ctx, m, m)) | |||||
{ | |||||
int num_retvals; | |||||
c_api.TFE_Execute(matmul, retvals, out num_retvals, status); | |||||
EXPECT_EQ(1, num_retvals); | |||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
} | |||||
int num_retvals; | |||||
c_api.TFE_Execute(matmul, retvals, out num_retvals, status); | |||||
EXPECT_EQ(1, num_retvals); | |||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
} | |||||
try | |||||
{ | |||||
t = TFE_TensorHandleResolve(retvals[0], status); | t = TFE_TensorHandleResolve(retvals[0], status); | ||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
} | } | ||||
finally | finally | ||||
{ | { | ||||
retvals[0]?.Dispose(); | |||||
retvals[0].Dispose(); | |||||
} | } | ||||
} | } | ||||
@@ -51,6 +51,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
int num_retvals; | int num_retvals; | ||||
TFE_Execute(identityOp, retvals, out num_retvals, status); | TFE_Execute(identityOp, retvals, out num_retvals, status); | ||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
EXPECT_EQ(2, num_retvals); | |||||
try | try | ||||
{ | { | ||||
@@ -62,8 +63,8 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
} | } | ||||
finally | finally | ||||
{ | { | ||||
retvals[0]?.Dispose(); | |||||
retvals[1]?.Dispose(); | |||||
retvals[0].Dispose(); | |||||
retvals[1].Dispose(); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -46,8 +46,9 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
int num_retvals; | int num_retvals; | ||||
TFE_Execute(assertOp, retvals, out num_retvals, status); | TFE_Execute(assertOp, retvals, out num_retvals, status); | ||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
EXPECT_EQ(1, num_retvals); | |||||
retvals[0]?.Dispose(); | |||||
retvals[0].Dispose(); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -49,6 +49,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
int num_retvals; | int num_retvals; | ||||
c_api.TFE_Execute(shape_op, retvals, out num_retvals, status); | c_api.TFE_Execute(shape_op, retvals, out num_retvals, status); | ||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | ||||
ASSERT_EQ(1, num_retvals); | |||||
try | try | ||||
{ | { | ||||
@@ -64,7 +65,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
} | } | ||||
finally | finally | ||||
{ | { | ||||
retvals[0]?.Dispose(); | |||||
retvals[0].Dispose(); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -38,6 +38,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
TFE_OpAddInput(op, var_handle, status); | TFE_OpAddInput(op, var_handle, status); | ||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
TFE_Execute(op, value_handle, out num_retvals, status); | TFE_Execute(op, value_handle, out num_retvals, status); | ||||
ASSERT_EQ(1, num_retvals); | |||||
} | } | ||||
try | try | ||||
@@ -57,7 +58,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
} | } | ||||
finally | finally | ||||
{ | { | ||||
value_handle[0]?.Dispose(); | |||||
value_handle[0].Dispose(); | |||||
} | } | ||||
} | } | ||||
@@ -90,11 +90,10 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
TFE_OpSetAttrString(op, "shared_name", "", 0); | TFE_OpSetAttrString(op, "shared_name", "", 0); | ||||
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | ||||
TFE_Execute(op, var_handle, out num_retvals, status); | TFE_Execute(op, var_handle, out num_retvals, status); | ||||
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
CHECK_EQ(1, num_retvals); | |||||
} | } | ||||
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
CHECK_EQ(1, num_retvals); | |||||
// Assign 'value' to it. | // Assign 'value' to it. | ||||
using (var op = TFE_NewOp(ctx, "AssignVariableOp", status)) | using (var op = TFE_NewOp(ctx, "AssignVariableOp", status)) | ||||
{ | { | ||||
@@ -112,13 +111,11 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
TFE_OpAddInput(op, value_handle, status); | TFE_OpAddInput(op, value_handle, status); | ||||
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | ||||
num_retvals = 0; | |||||
c_api.TFE_Execute(op, null, out num_retvals, status); | c_api.TFE_Execute(op, null, out num_retvals, status); | ||||
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
CHECK_EQ(0, num_retvals); | |||||
} | } | ||||
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
CHECK_EQ(0, num_retvals); | |||||
return var_handle[0]; | return var_handle[0]; | ||||
} | } | ||||