Browse Source

Implement SafeContextOptionsHandle as a wrapper for TFE_ContextOptions

tags/v0.20
Sam Harwell 5 years ago
parent
commit
6d73b8d61e
12 changed files with 121 additions and 75 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Eager/Context.cs
  2. +25
    -18
      src/TensorFlowNET.Core/Eager/ContextOptions.cs
  3. +40
    -0
      src/TensorFlowNET.Core/Eager/SafeContextOptionsHandle.cs
  4. +0
    -23
      src/TensorFlowNET.Core/Eager/TFE_ContextOptions.cs
  5. +3
    -3
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  6. +2
    -5
      test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs
  7. +8
    -4
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs
  8. +9
    -4
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs
  9. +8
    -4
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs
  10. +8
    -6
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs
  11. +8
    -4
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs
  12. +9
    -3
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs

+ 1
- 1
src/TensorFlowNET.Core/Eager/Context.cs View File

@@ -16,7 +16,7 @@ namespace Tensorflow.Eager

public Context(ContextOptions opts, Status status)
{
Handle = c_api.TFE_NewContext(opts, status.Handle);
Handle = c_api.TFE_NewContext(opts.Handle, status.Handle);
status.Check(true);
}



+ 25
- 18
src/TensorFlowNET.Core/Eager/ContextOptions.cs View File

@@ -1,26 +1,33 @@
using System;
using System.IO;
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;

namespace Tensorflow.Eager
{
public class ContextOptions : DisposableObject
public sealed class ContextOptions : IDisposable
{
public ContextOptions() : base(c_api.TFE_NewContextOptions())
{ }
public SafeContextOptionsHandle Handle { get; }

/// <summary>
/// Dispose any unmanaged resources related to given <paramref name="handle"/>.
/// </summary>
protected sealed override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TFE_DeleteContextOptions(_handle);
public ContextOptions()
{
Handle = c_api.TFE_NewContextOptions();
}


public static implicit operator IntPtr(ContextOptions opts)
=> opts._handle;

public static implicit operator TFE_ContextOptions(ContextOptions opts)
=> new TFE_ContextOptions(opts._handle);
public void Dispose()
=> Handle.Dispose();
}

}

+ 40
- 0
src/TensorFlowNET.Core/Eager/SafeContextOptionsHandle.cs View File

@@ -0,0 +1,40 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using Tensorflow.Util;

namespace Tensorflow.Eager
{
public sealed class SafeContextOptionsHandle : SafeTensorflowHandle
{
public SafeContextOptionsHandle()
{
}

public SafeContextOptionsHandle(IntPtr handle)
: base(handle)
{
}

protected override bool ReleaseHandle()
{
c_api.TFE_DeleteContextOptions(handle);
SetHandle(IntPtr.Zero);
return true;
}
}
}

+ 0
- 23
src/TensorFlowNET.Core/Eager/TFE_ContextOptions.cs View File

@@ -1,23 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Eager
{
public struct TFE_ContextOptions
{
IntPtr _handle;

public TFE_ContextOptions(IntPtr handle)
=> _handle = handle;

public static implicit operator TFE_ContextOptions(IntPtr handle)
=> new TFE_ContextOptions(handle);

public static implicit operator IntPtr(TFE_ContextOptions tensor)
=> tensor._handle;

public override string ToString()
=> $"TFE_ContextOptions {_handle}";
}
}

+ 3
- 3
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -52,7 +52,7 @@ namespace Tensorflow
/// </summary>
/// <returns>TFE_ContextOptions*</returns>
[DllImport(TensorFlowLibName)]
public static extern TFE_ContextOptions TFE_NewContextOptions();
public static extern SafeContextOptionsHandle TFE_NewContextOptions();

/// <summary>
/// Destroy an options object.
@@ -114,7 +114,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns>TFE_Context*</returns>
[DllImport(TensorFlowLibName)]
public static extern SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status);
public static extern SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status);

[DllImport(TensorFlowLibName)]
public static extern void TFE_ContextStartStep(SafeContextHandle ctx);
@@ -254,7 +254,7 @@ namespace Tensorflow
/// <param name="opts">TFE_ContextOptions*</param>
/// <param name="enable">unsigned char</param>
[DllImport(TensorFlowLibName)]
public static extern void TFE_ContextOptionsSetAsync(IntPtr opts, byte enable);
public static extern void TFE_ContextOptionsSetAsync(SafeContextOptionsHandle opts, byte enable);

/// <summary>
///


+ 2
- 5
test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs View File

@@ -102,15 +102,12 @@ namespace TensorFlowNET.UnitTest
protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status)
=> c_api.TFE_Execute(op, retvals, ref num_retvals, status);

protected IntPtr TFE_NewContextOptions()
protected SafeContextOptionsHandle TFE_NewContextOptions()
=> c_api.TFE_NewContextOptions();

protected SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status)
protected SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status)
=> c_api.TFE_NewContext(opts, status);

protected void TFE_DeleteContextOptions(IntPtr opts)
=> c_api.TFE_DeleteContextOptions(opts);

protected int TFE_OpGetInputLength(IntPtr op, string input_name, SafeStatusHandle status)
=> c_api.TFE_OpGetInputLength(op, input_name, status);



+ 8
- 4
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs View File

@@ -1,6 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Eager;

namespace TensorFlowNET.UnitTest.NativeAPI
{
@@ -13,13 +14,16 @@ namespace TensorFlowNET.UnitTest.NativeAPI
public void Context()
{
using var status = c_api.TF_NewStatus();
var opts = c_api.TFE_NewContextOptions();

IntPtr devices;
using (var ctx = c_api.TFE_NewContext(opts, status))
static SafeContextHandle NewContext(SafeStatusHandle status)
{
c_api.TFE_DeleteContextOptions(opts);
using var opts = c_api.TFE_NewContextOptions();
return c_api.TFE_NewContext(opts, status);
}

IntPtr devices;
using (var ctx = NewContext(status))
{
devices = c_api.TFE_ContextListDevices(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
}


+ 9
- 4
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs View File

@@ -1,6 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.NativeAPI
@@ -19,14 +20,18 @@ namespace TensorFlowNET.UnitTest.NativeAPI
unsafe void Execute_MatMul_CPU(bool async)
{
using var status = TF_NewStatus();
var opts = TFE_NewContextOptions();
c_api.TFE_ContextOptionsSetAsync(opts, Convert.ToByte(async));

static SafeContextHandle NewContext(bool async, SafeStatusHandle status)
{
using var opts = c_api.TFE_NewContextOptions();
c_api.TFE_ContextOptionsSetAsync(opts, Convert.ToByte(async));
return c_api.TFE_NewContext(opts, status);
}

IntPtr t;
using (var ctx = TFE_NewContext(opts, status))
using (var ctx = NewContext(async, status))
{
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteContextOptions(opts);

var m = TestMatrixTensorHandle();
var matmul = MatMulOp(ctx, m, m);


+ 8
- 4
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs View File

@@ -2,7 +2,6 @@
using System;
using Tensorflow;
using Tensorflow.Eager;
using Buffer = System.Buffer;

namespace TensorFlowNET.UnitTest.NativeAPI
{
@@ -15,10 +14,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI
public unsafe void OpGetInputAndOutputLengths()
{
using var status = TF_NewStatus();
var opts = TFE_NewContextOptions();
using var ctx = TFE_NewContext(opts, status);

static SafeContextHandle NewContext(SafeStatusHandle status)
{
using var opts = c_api.TFE_NewContextOptions();
return c_api.TFE_NewContext(opts, status);
}

using var ctx = NewContext(status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteContextOptions(opts);

var input1 = TestMatrixTensorHandle();
var input2 = TestMatrixTensorHandle();


+ 8
- 6
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs View File

@@ -1,10 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using Tensorflow;
using Tensorflow.Eager;
using Buffer = System.Buffer;
using System.Linq;

namespace TensorFlowNET.UnitTest.NativeAPI
{
@@ -17,10 +14,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI
public unsafe void OpInferMixedTypeInputListAttrs()
{
using var status = TF_NewStatus();
var opts = TFE_NewContextOptions();
using var ctx = TFE_NewContext(opts, status);

static SafeContextHandle NewContext(SafeStatusHandle status)
{
using var opts = c_api.TFE_NewContextOptions();
return c_api.TFE_NewContext(opts, status);
}

using var ctx = NewContext(status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteContextOptions(opts);

var condition = TestScalarTensorHandle(true);
var t1 = TestMatrixTensorHandle();


+ 8
- 4
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs View File

@@ -2,7 +2,6 @@
using System;
using Tensorflow;
using Tensorflow.Eager;
using Buffer = System.Buffer;

namespace TensorFlowNET.UnitTest.NativeAPI
{
@@ -15,9 +14,14 @@ namespace TensorFlowNET.UnitTest.NativeAPI
public unsafe void TensorHandleDevices()
{
var status = c_api.TF_NewStatus();
var opts = TFE_NewContextOptions();
using var ctx = TFE_NewContext(opts, status);
TFE_DeleteContextOptions(opts);

static SafeContextHandle NewContext(SafeStatusHandle status)
{
using var opts = c_api.TFE_NewContextOptions();
return c_api.TFE_NewContext(opts, status);
}

using var ctx = NewContext(status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

var hcpu = TestMatrixTensorHandle();


+ 9
- 3
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs View File

@@ -1,6 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.NativeAPI
@@ -14,10 +15,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI
public unsafe void Variables()
{
using var status = c_api.TF_NewStatus();
var opts = TFE_NewContextOptions();
using var ctx = TFE_NewContext(opts, status);

static SafeContextHandle NewContext(SafeStatusHandle status)
{
using var opts = c_api.TFE_NewContextOptions();
return c_api.TFE_NewContext(opts, status);
}

using var ctx = NewContext(status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteContextOptions(opts);

var var_handle = CreateVariable(ctx, 12.0f, status);
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));


Loading…
Cancel
Save