@@ -16,7 +16,7 @@ namespace Tensorflow.Eager | |||||
public Context(ContextOptions opts, Status status) | public Context(ContextOptions opts, Status status) | ||||
{ | { | ||||
Handle = c_api.TFE_NewContext(opts, status.Handle); | |||||
Handle = c_api.TFE_NewContext(opts.Handle, status.Handle); | |||||
status.Check(true); | status.Check(true); | ||||
} | } | ||||
@@ -1,26 +1,33 @@ | |||||
using System; | |||||
using System.IO; | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
{ | { | ||||
public class ContextOptions : DisposableObject | |||||
public sealed class ContextOptions : IDisposable | |||||
{ | { | ||||
public ContextOptions() : base(c_api.TFE_NewContextOptions()) | |||||
{ } | |||||
public SafeContextOptionsHandle Handle { get; } | |||||
/// <summary> | |||||
/// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||||
/// </summary> | |||||
protected sealed override void DisposeUnmanagedResources(IntPtr handle) | |||||
=> c_api.TFE_DeleteContextOptions(_handle); | |||||
public ContextOptions() | |||||
{ | |||||
Handle = c_api.TFE_NewContextOptions(); | |||||
} | |||||
public static implicit operator IntPtr(ContextOptions opts) | |||||
=> opts._handle; | |||||
public static implicit operator TFE_ContextOptions(ContextOptions opts) | |||||
=> new TFE_ContextOptions(opts._handle); | |||||
public void Dispose() | |||||
=> Handle.Dispose(); | |||||
} | } | ||||
} | } |
@@ -0,0 +1,40 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using Tensorflow.Util; | |||||
namespace Tensorflow.Eager | |||||
{ | |||||
public sealed class SafeContextOptionsHandle : SafeTensorflowHandle | |||||
{ | |||||
public SafeContextOptionsHandle() | |||||
{ | |||||
} | |||||
public SafeContextOptionsHandle(IntPtr handle) | |||||
: base(handle) | |||||
{ | |||||
} | |||||
protected override bool ReleaseHandle() | |||||
{ | |||||
c_api.TFE_DeleteContextOptions(handle); | |||||
SetHandle(IntPtr.Zero); | |||||
return true; | |||||
} | |||||
} | |||||
} |
@@ -1,23 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Eager | |||||
{ | |||||
public struct TFE_ContextOptions | |||||
{ | |||||
IntPtr _handle; | |||||
public TFE_ContextOptions(IntPtr handle) | |||||
=> _handle = handle; | |||||
public static implicit operator TFE_ContextOptions(IntPtr handle) | |||||
=> new TFE_ContextOptions(handle); | |||||
public static implicit operator IntPtr(TFE_ContextOptions tensor) | |||||
=> tensor._handle; | |||||
public override string ToString() | |||||
=> $"TFE_ContextOptions {_handle}"; | |||||
} | |||||
} |
@@ -52,7 +52,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
/// <returns>TFE_ContextOptions*</returns> | /// <returns>TFE_ContextOptions*</returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TFE_ContextOptions TFE_NewContextOptions(); | |||||
public static extern SafeContextOptionsHandle TFE_NewContextOptions(); | |||||
/// <summary> | /// <summary> | ||||
/// Destroy an options object. | /// Destroy an options object. | ||||
@@ -114,7 +114,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns>TFE_Context*</returns> | /// <returns>TFE_Context*</returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status); | |||||
public static extern SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_ContextStartStep(SafeContextHandle ctx); | public static extern void TFE_ContextStartStep(SafeContextHandle ctx); | ||||
@@ -254,7 +254,7 @@ namespace Tensorflow | |||||
/// <param name="opts">TFE_ContextOptions*</param> | /// <param name="opts">TFE_ContextOptions*</param> | ||||
/// <param name="enable">unsigned char</param> | /// <param name="enable">unsigned char</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_ContextOptionsSetAsync(IntPtr opts, byte enable); | |||||
public static extern void TFE_ContextOptionsSetAsync(SafeContextOptionsHandle opts, byte enable); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -102,15 +102,12 @@ namespace TensorFlowNET.UnitTest | |||||
protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status) | protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status) | ||||
=> c_api.TFE_Execute(op, retvals, ref num_retvals, status); | => c_api.TFE_Execute(op, retvals, ref num_retvals, status); | ||||
protected IntPtr TFE_NewContextOptions() | |||||
protected SafeContextOptionsHandle TFE_NewContextOptions() | |||||
=> c_api.TFE_NewContextOptions(); | => c_api.TFE_NewContextOptions(); | ||||
protected SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status) | |||||
protected SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status) | |||||
=> c_api.TFE_NewContext(opts, status); | => c_api.TFE_NewContext(opts, status); | ||||
protected void TFE_DeleteContextOptions(IntPtr opts) | |||||
=> c_api.TFE_DeleteContextOptions(opts); | |||||
protected int TFE_OpGetInputLength(IntPtr op, string input_name, SafeStatusHandle status) | protected int TFE_OpGetInputLength(IntPtr op, string input_name, SafeStatusHandle status) | ||||
=> c_api.TFE_OpGetInputLength(op, input_name, status); | => c_api.TFE_OpGetInputLength(op, input_name, status); | ||||
@@ -1,6 +1,7 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | using System; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Eager; | |||||
namespace TensorFlowNET.UnitTest.NativeAPI | namespace TensorFlowNET.UnitTest.NativeAPI | ||||
{ | { | ||||
@@ -13,13 +14,16 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
public void Context() | public void Context() | ||||
{ | { | ||||
using var status = c_api.TF_NewStatus(); | using var status = c_api.TF_NewStatus(); | ||||
var opts = c_api.TFE_NewContextOptions(); | |||||
IntPtr devices; | |||||
using (var ctx = c_api.TFE_NewContext(opts, status)) | |||||
static SafeContextHandle NewContext(SafeStatusHandle status) | |||||
{ | { | ||||
c_api.TFE_DeleteContextOptions(opts); | |||||
using var opts = c_api.TFE_NewContextOptions(); | |||||
return c_api.TFE_NewContext(opts, status); | |||||
} | |||||
IntPtr devices; | |||||
using (var ctx = NewContext(status)) | |||||
{ | |||||
devices = c_api.TFE_ContextListDevices(ctx, status); | devices = c_api.TFE_ContextListDevices(ctx, status); | ||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
} | } | ||||
@@ -1,6 +1,7 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | using System; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Eager; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.UnitTest.NativeAPI | namespace TensorFlowNET.UnitTest.NativeAPI | ||||
@@ -19,14 +20,18 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
unsafe void Execute_MatMul_CPU(bool async) | unsafe void Execute_MatMul_CPU(bool async) | ||||
{ | { | ||||
using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
var opts = TFE_NewContextOptions(); | |||||
c_api.TFE_ContextOptionsSetAsync(opts, Convert.ToByte(async)); | |||||
static SafeContextHandle NewContext(bool async, SafeStatusHandle status) | |||||
{ | |||||
using var opts = c_api.TFE_NewContextOptions(); | |||||
c_api.TFE_ContextOptionsSetAsync(opts, Convert.ToByte(async)); | |||||
return c_api.TFE_NewContext(opts, status); | |||||
} | |||||
IntPtr t; | IntPtr t; | ||||
using (var ctx = TFE_NewContext(opts, status)) | |||||
using (var ctx = NewContext(async, status)) | |||||
{ | { | ||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
TFE_DeleteContextOptions(opts); | |||||
var m = TestMatrixTensorHandle(); | var m = TestMatrixTensorHandle(); | ||||
var matmul = MatMulOp(ctx, m, m); | var matmul = MatMulOp(ctx, m, m); | ||||
@@ -2,7 +2,6 @@ | |||||
using System; | using System; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Buffer = System.Buffer; | |||||
namespace TensorFlowNET.UnitTest.NativeAPI | namespace TensorFlowNET.UnitTest.NativeAPI | ||||
{ | { | ||||
@@ -15,10 +14,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
public unsafe void OpGetInputAndOutputLengths() | public unsafe void OpGetInputAndOutputLengths() | ||||
{ | { | ||||
using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
var opts = TFE_NewContextOptions(); | |||||
using var ctx = TFE_NewContext(opts, status); | |||||
static SafeContextHandle NewContext(SafeStatusHandle status) | |||||
{ | |||||
using var opts = c_api.TFE_NewContextOptions(); | |||||
return c_api.TFE_NewContext(opts, status); | |||||
} | |||||
using var ctx = NewContext(status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
TFE_DeleteContextOptions(opts); | |||||
var input1 = TestMatrixTensorHandle(); | var input1 = TestMatrixTensorHandle(); | ||||
var input2 = TestMatrixTensorHandle(); | var input2 = TestMatrixTensorHandle(); | ||||
@@ -1,10 +1,7 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | using System; | ||||
using System.Collections.Generic; | |||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Buffer = System.Buffer; | |||||
using System.Linq; | |||||
namespace TensorFlowNET.UnitTest.NativeAPI | namespace TensorFlowNET.UnitTest.NativeAPI | ||||
{ | { | ||||
@@ -17,10 +14,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
public unsafe void OpInferMixedTypeInputListAttrs() | public unsafe void OpInferMixedTypeInputListAttrs() | ||||
{ | { | ||||
using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
var opts = TFE_NewContextOptions(); | |||||
using var ctx = TFE_NewContext(opts, status); | |||||
static SafeContextHandle NewContext(SafeStatusHandle status) | |||||
{ | |||||
using var opts = c_api.TFE_NewContextOptions(); | |||||
return c_api.TFE_NewContext(opts, status); | |||||
} | |||||
using var ctx = NewContext(status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
TFE_DeleteContextOptions(opts); | |||||
var condition = TestScalarTensorHandle(true); | var condition = TestScalarTensorHandle(true); | ||||
var t1 = TestMatrixTensorHandle(); | var t1 = TestMatrixTensorHandle(); | ||||
@@ -2,7 +2,6 @@ | |||||
using System; | using System; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Buffer = System.Buffer; | |||||
namespace TensorFlowNET.UnitTest.NativeAPI | namespace TensorFlowNET.UnitTest.NativeAPI | ||||
{ | { | ||||
@@ -15,9 +14,14 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
public unsafe void TensorHandleDevices() | public unsafe void TensorHandleDevices() | ||||
{ | { | ||||
var status = c_api.TF_NewStatus(); | var status = c_api.TF_NewStatus(); | ||||
var opts = TFE_NewContextOptions(); | |||||
using var ctx = TFE_NewContext(opts, status); | |||||
TFE_DeleteContextOptions(opts); | |||||
static SafeContextHandle NewContext(SafeStatusHandle status) | |||||
{ | |||||
using var opts = c_api.TFE_NewContextOptions(); | |||||
return c_api.TFE_NewContext(opts, status); | |||||
} | |||||
using var ctx = NewContext(status); | |||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
var hcpu = TestMatrixTensorHandle(); | var hcpu = TestMatrixTensorHandle(); | ||||
@@ -1,6 +1,7 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | using System; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Eager; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.UnitTest.NativeAPI | namespace TensorFlowNET.UnitTest.NativeAPI | ||||
@@ -14,10 +15,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
public unsafe void Variables() | public unsafe void Variables() | ||||
{ | { | ||||
using var status = c_api.TF_NewStatus(); | using var status = c_api.TF_NewStatus(); | ||||
var opts = TFE_NewContextOptions(); | |||||
using var ctx = TFE_NewContext(opts, status); | |||||
static SafeContextHandle NewContext(SafeStatusHandle status) | |||||
{ | |||||
using var opts = c_api.TFE_NewContextOptions(); | |||||
return c_api.TFE_NewContext(opts, status); | |||||
} | |||||
using var ctx = NewContext(status); | |||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
TFE_DeleteContextOptions(opts); | |||||
var var_handle = CreateVariable(ctx, 12.0f, status); | var var_handle = CreateVariable(ctx, 12.0f, status); | ||||
ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||