Browse Source

Simplify code using SafeTensorHandleHandle

tags/v0.20
Sam Harwell Haiping 5 years ago
parent
commit
941a241055
7 changed files with 38 additions and 22 deletions
  1. +15
    -0
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  2. +10
    -10
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs
  3. +3
    -2
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs
  4. +2
    -1
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs
  5. +2
    -1
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs
  6. +2
    -1
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs
  7. +4
    -7
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs

+ 15
- 0
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -91,6 +91,18 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TFE_DeleteContext(IntPtr ctx);

/// <summary>
/// Execute the operation defined by <paramref name="op"/> and return handles to computed
/// tensors in <paramref name="retvals"/>.
/// </summary>
/// <remarks>
/// Upon successful return, the first <paramref name="num_retvals"/> slots in <paramref name="retvals"/> will
/// contain handle instances which the caller is responsible for disposing once they are no longer in use.
/// </remarks>
/// <param name="op"></param>
/// <param name="retvals"></param>
/// <param name="num_retvals"></param>
/// <param name="status"></param>
public static void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status)
{
unsafe
@@ -100,6 +112,9 @@ namespace Tensorflow
TFE_Execute(op, rawReturns, ref num_retvals, status);
for (var i = 0; i < num_retvals; i++)
{
// A handle is created for every return, even if rawReturns[i] is null. The resulting handle will be
// non-null but invalid, which is the same behavior P/Invoke gives for non-array SafeHandle return
// values.
retvals[i] = new SafeTensorHandleHandle(rawReturns[i]);
}
}


+ 10
- 10
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs View File

@@ -34,23 +34,23 @@ namespace TensorFlowNET.UnitTest.NativeAPI
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

var retvals = new SafeTensorHandleHandle[2];
try
using (var m = TestMatrixTensorHandle())
using (var matmul = MatMulOp(ctx, m, m))
{
using (var m = TestMatrixTensorHandle())
using (var matmul = MatMulOp(ctx, m, m))
{
int num_retvals;
c_api.TFE_Execute(matmul, retvals, out num_retvals, status);
EXPECT_EQ(1, num_retvals);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
}
int num_retvals;
c_api.TFE_Execute(matmul, retvals, out num_retvals, status);
EXPECT_EQ(1, num_retvals);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
}

try
{
t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
}
finally
{
retvals[0]?.Dispose();
retvals[0].Dispose();
}
}



+ 3
- 2
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs View File

@@ -51,6 +51,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
int num_retvals;
TFE_Execute(identityOp, retvals, out num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
EXPECT_EQ(2, num_retvals);

try
{
@@ -62,8 +63,8 @@ namespace TensorFlowNET.UnitTest.NativeAPI
}
finally
{
retvals[0]?.Dispose();
retvals[1]?.Dispose();
retvals[0].Dispose();
retvals[1].Dispose();
}
}
}


+ 2
- 1
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs View File

@@ -46,8 +46,9 @@ namespace TensorFlowNET.UnitTest.NativeAPI
int num_retvals;
TFE_Execute(assertOp, retvals, out num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
EXPECT_EQ(1, num_retvals);

retvals[0]?.Dispose();
retvals[0].Dispose();
}
}
}


+ 2
- 1
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs View File

@@ -49,6 +49,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
int num_retvals;
c_api.TFE_Execute(shape_op, retvals, out num_retvals, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));
ASSERT_EQ(1, num_retvals);

try
{
@@ -64,7 +65,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
}
finally
{
retvals[0]?.Dispose();
retvals[0].Dispose();
}
}
}


+ 2
- 1
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs View File

@@ -38,6 +38,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
TFE_OpAddInput(op, var_handle, status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_Execute(op, value_handle, out num_retvals, status);
ASSERT_EQ(1, num_retvals);
}

try
@@ -57,7 +58,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
}
finally
{
value_handle[0]?.Dispose();
value_handle[0].Dispose();
}
}



+ 4
- 7
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs View File

@@ -90,11 +90,10 @@ namespace TensorFlowNET.UnitTest.NativeAPI
TFE_OpSetAttrString(op, "shared_name", "", 0);
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
TFE_Execute(op, var_handle, out num_retvals, status);
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
CHECK_EQ(1, num_retvals);
}

if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
CHECK_EQ(1, num_retvals);

// Assign 'value' to it.
using (var op = TFE_NewOp(ctx, "AssignVariableOp", status))
{
@@ -112,13 +111,11 @@ namespace TensorFlowNET.UnitTest.NativeAPI
TFE_OpAddInput(op, value_handle, status);
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);

num_retvals = 0;
c_api.TFE_Execute(op, null, out num_retvals, status);
if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
CHECK_EQ(0, num_retvals);
}

if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero);
CHECK_EQ(0, num_retvals);

return var_handle[0];
}



Loading…
Cancel
Save